diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 648a505e65..ff6bb475b5 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -66,10 +66,6 @@ class FederationHandler(BaseHandler):
self.hs = hs
- self.distributor.observe("user_joined_room", self.user_joined_room)
-
- self.waiting_for_join_list = {}
-
self.store = hs.get_datastore()
self.replication_layer = hs.get_replication_layer()
self.state_handler = hs.get_state_handler()
@@ -128,7 +124,7 @@ class FederationHandler(BaseHandler):
try:
event_stream_id, max_stream_id = yield self._persist_auth_tree(
- auth_chain, state, event
+ origin, auth_chain, state, event
)
except AuthError as e:
raise FederationError(
@@ -253,7 +249,7 @@ class FederationHandler(BaseHandler):
if ev.type != EventTypes.Member:
continue
try:
- domain = UserID.from_string(ev.state_key).domain
+ domain = get_domain_from_id(ev.state_key)
except:
continue
@@ -339,29 +335,58 @@ class FederationHandler(BaseHandler):
state_events.update({s.event_id: s for s in state})
events_to_state[e_id] = state
- seen_events = yield self.store.have_events(
- set(auth_events.keys()) | set(state_events.keys())
- )
-
- all_events = events + state_events.values() + auth_events.values()
required_auth = set(
- a_id for event in all_events for a_id, _ in event.auth_events
+ a_id
+ for event in events + state_events.values() + auth_events.values()
+ for a_id, _ in event.auth_events
)
-
+ auth_events.update({
+ e_id: event_map[e_id] for e_id in required_auth if e_id in event_map
+ })
missing_auth = required_auth - set(auth_events)
- results = yield defer.gatherResults(
- [
- self.replication_layer.get_pdu(
- [dest],
- event_id,
- outlier=True,
- timeout=10000,
+ failed_to_fetch = set()
+
+ # Try and fetch any missing auth events from both DB and remote servers.
+ # We repeatedly do this until we stop finding new auth events.
+ while missing_auth - failed_to_fetch:
+ logger.info("Missing auth for backfill: %r", missing_auth)
+ ret_events = yield self.store.get_events(missing_auth - failed_to_fetch)
+ auth_events.update(ret_events)
+
+ required_auth.update(
+ a_id for event in ret_events.values() for a_id, _ in event.auth_events
+ )
+ missing_auth = required_auth - set(auth_events)
+
+ if missing_auth - failed_to_fetch:
+ logger.info(
+ "Fetching missing auth for backfill: %r",
+ missing_auth - failed_to_fetch
)
- for event_id in missing_auth
- ],
- consumeErrors=True
- ).addErrback(unwrapFirstError)
- auth_events.update({a.event_id: a for a in results})
+
+ results = yield defer.gatherResults(
+ [
+ self.replication_layer.get_pdu(
+ [dest],
+ event_id,
+ outlier=True,
+ timeout=10000,
+ )
+ for event_id in missing_auth - failed_to_fetch
+ ],
+ consumeErrors=True
+ ).addErrback(unwrapFirstError)
+ auth_events.update({a.event_id: a for a in results})
+ required_auth.update(
+ a_id for event in results for a_id, _ in event.auth_events
+ )
+ missing_auth = required_auth - set(auth_events)
+
+ failed_to_fetch = missing_auth - set(auth_events)
+
+ seen_events = yield self.store.have_events(
+ set(auth_events.keys()) | set(state_events.keys())
+ )
ev_infos = []
for a in auth_events.values():
@@ -374,6 +399,7 @@ class FederationHandler(BaseHandler):
(auth_events[a_id].type, auth_events[a_id].state_key):
auth_events[a_id]
for a_id, _ in a.auth_events
+ if a_id in auth_events
}
})
@@ -385,6 +411,7 @@ class FederationHandler(BaseHandler):
(auth_events[a_id].type, auth_events[a_id].state_key):
auth_events[a_id]
for a_id, _ in event_map[e_id].auth_events
+ if a_id in auth_events
}
})
@@ -403,7 +430,7 @@ class FederationHandler(BaseHandler):
# previous to work out the state.
# TODO: We can probably do something more clever here.
yield self._handle_new_event(
- dest, event
+ dest, event, backfilled=True,
)
defer.returnValue(events)
@@ -639,7 +666,7 @@ class FederationHandler(BaseHandler):
pass
event_stream_id, max_stream_id = yield self._persist_auth_tree(
- auth_chain, state, event
+ origin, auth_chain, state, event
)
with PreserveLoggingContext():
@@ -690,7 +717,9 @@ class FederationHandler(BaseHandler):
logger.warn("Failed to create join %r because %s", event, e)
raise e
- self.auth.check(event, auth_events=context.current_state)
+ # The remote hasn't signed it yet, obviously. We'll do the full checks
+ # when we get the event back in `on_send_join_request`
+ self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
defer.returnValue(event)
@@ -920,7 +949,9 @@ class FederationHandler(BaseHandler):
)
try:
- self.auth.check(event, auth_events=context.current_state)
+ # The remote hasn't signed it yet, obviously. We'll do the full checks
+ # when we get the event back in `on_send_leave_request`
+ self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
except AuthError as e:
logger.warn("Failed to create new leave %r because %s", event, e)
raise e
@@ -989,14 +1020,9 @@ class FederationHandler(BaseHandler):
defer.returnValue(None)
@defer.inlineCallbacks
- def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True):
+ def get_state_for_pdu(self, room_id, event_id):
yield run_on_reactor()
- if do_auth:
- in_room = yield self.auth.check_host_in_room(room_id, origin)
- if not in_room:
- raise AuthError(403, "Host not in room.")
-
state_groups = yield self.store.get_state_groups(
room_id, [event_id]
)
@@ -1020,13 +1046,16 @@ class FederationHandler(BaseHandler):
res = results.values()
for event in res:
- event.signatures.update(
- compute_event_signature(
- event,
- self.hs.hostname,
- self.hs.config.signing_key[0]
+ # We sign these again because there was a bug where we
+ # incorrectly signed things the first time round
+ if self.hs.is_mine_id(event.event_id):
+ event.signatures.update(
+ compute_event_signature(
+ event,
+ self.hs.hostname,
+ self.hs.config.signing_key[0]
+ )
)
- )
defer.returnValue(res)
else:
@@ -1064,16 +1093,17 @@ class FederationHandler(BaseHandler):
)
if event:
- # FIXME: This is a temporary work around where we occasionally
- # return events slightly differently than when they were
- # originally signed
- event.signatures.update(
- compute_event_signature(
- event,
- self.hs.hostname,
- self.hs.config.signing_key[0]
+ if self.hs.is_mine_id(event.event_id):
+ # FIXME: This is a temporary work around where we occasionally
+ # return events slightly differently than when they were
+ # originally signed
+ event.signatures.update(
+ compute_event_signature(
+ event,
+ self.hs.hostname,
+ self.hs.config.signing_key[0]
+ )
)
- )
if do_auth:
in_room = yield self.auth.check_host_in_room(
@@ -1083,6 +1113,12 @@ class FederationHandler(BaseHandler):
if not in_room:
raise AuthError(403, "Host not in room.")
+ events = yield self._filter_events_for_server(
+ origin, event.room_id, [event]
+ )
+
+ event = events[0]
+
defer.returnValue(event)
else:
defer.returnValue(None)
@@ -1091,15 +1127,6 @@ class FederationHandler(BaseHandler):
def get_min_depth_for_context(self, context):
return self.store.get_min_depth(context)
- @log_function
- def user_joined_room(self, user, room_id):
- waiters = self.waiting_for_join_list.get(
- (user.to_string(), room_id),
- []
- )
- while waiters:
- waiters.pop().callback(None)
-
@defer.inlineCallbacks
@log_function
def _handle_new_event(self, origin, event, state=None, auth_events=None,
@@ -1122,11 +1149,12 @@ class FederationHandler(BaseHandler):
backfilled=backfilled,
)
- # this intentionally does not yield: we don't care about the result
- # and don't need to wait for it.
- preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
- event_stream_id, max_stream_id
- )
+ if not backfilled:
+ # this intentionally does not yield: we don't care about the result
+ # and don't need to wait for it.
+ preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
+ event_stream_id, max_stream_id
+ )
defer.returnValue((context, event_stream_id, max_stream_id))
@@ -1158,11 +1186,19 @@ class FederationHandler(BaseHandler):
)
@defer.inlineCallbacks
- def _persist_auth_tree(self, auth_events, state, event):
+ def _persist_auth_tree(self, origin, auth_events, state, event):
"""Checks the auth chain is valid (and passes auth checks) for the
state and event. Then persists the auth chain and state atomically.
Persists the event seperately.
+ Will attempt to fetch missing auth events.
+
+ Args:
+ origin (str): Where the events came from
+ auth_events (list)
+ state (list)
+ event (Event)
+
Returns:
2-tuple of (event_stream_id, max_stream_id) from the persist_event
call for `event`
@@ -1175,7 +1211,7 @@ class FederationHandler(BaseHandler):
event_map = {
e.event_id: e
- for e in auth_events
+ for e in itertools.chain(auth_events, state, [event])
}
create_event = None
@@ -1184,10 +1220,29 @@ class FederationHandler(BaseHandler):
create_event = e
break
+ missing_auth_events = set()
+ for e in itertools.chain(auth_events, state, [event]):
+ for e_id, _ in e.auth_events:
+ if e_id not in event_map:
+ missing_auth_events.add(e_id)
+
+ for e_id in missing_auth_events:
+ m_ev = yield self.replication_layer.get_pdu(
+ [origin],
+ e_id,
+ outlier=True,
+ timeout=10000,
+ )
+ if m_ev and m_ev.event_id == e_id:
+ event_map[e_id] = m_ev
+ else:
+ logger.info("Failed to find auth event %r", e_id)
+
for e in itertools.chain(auth_events, state, [event]):
auth_for_e = {
(event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]
for e_id, _ in e.auth_events
+ if e_id in event_map
}
if create_event:
auth_for_e[(EventTypes.Create, "")] = create_event
@@ -1421,7 +1476,7 @@ class FederationHandler(BaseHandler):
local_view = dict(auth_events)
remote_view = dict(auth_events)
remote_view.update({
- (d.type, d.state_key): d for d in different_events
+ (d.type, d.state_key): d for d in different_events if d
})
new_state, prev_state = self.state_handler.resolve_events(
|