diff options
Diffstat (limited to 'synapse/handlers/federation.py')
-rw-r--r-- | synapse/handlers/federation.py | 134 |
1 files changed, 95 insertions, 39 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 6c0bc7eafa..618cb53629 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -124,7 +124,7 @@ class FederationHandler(BaseHandler): try: event_stream_id, max_stream_id = yield self._persist_auth_tree( - auth_chain, state, event + origin, auth_chain, state, event ) except AuthError as e: raise FederationError( @@ -335,31 +335,58 @@ class FederationHandler(BaseHandler): state_events.update({s.event_id: s for s in state}) events_to_state[e_id] = state - seen_events = yield self.store.have_events( - set(auth_events.keys()) | set(state_events.keys()) - ) - - all_events = events + state_events.values() + auth_events.values() required_auth = set( - a_id for event in all_events for a_id, _ in event.auth_events + a_id + for event in events + state_events.values() + auth_events.values() + for a_id, _ in event.auth_events ) - + auth_events.update({ + e_id: event_map[e_id] for e_id in required_auth if e_id in event_map + }) missing_auth = required_auth - set(auth_events) - if missing_auth: + failed_to_fetch = set() + + # Try and fetch any missing auth events from both DB and remote servers. + # We repeatedly do this until we stop finding new auth events. + while missing_auth - failed_to_fetch: logger.info("Missing auth for backfill: %r", missing_auth) - results = yield defer.gatherResults( - [ - self.replication_layer.get_pdu( - [dest], - event_id, - outlier=True, - timeout=10000, - ) - for event_id in missing_auth - ], - consumeErrors=True - ).addErrback(unwrapFirstError) - auth_events.update({a.event_id: a for a in results}) + ret_events = yield self.store.get_events(missing_auth - failed_to_fetch) + auth_events.update(ret_events) + + required_auth.update( + a_id for event in ret_events.values() for a_id, _ in event.auth_events + ) + missing_auth = required_auth - set(auth_events) + + if missing_auth - failed_to_fetch: + logger.info( + "Fetching missing auth for backfill: %r", + missing_auth - failed_to_fetch + ) + + results = yield defer.gatherResults( + [ + self.replication_layer.get_pdu( + [dest], + event_id, + outlier=True, + timeout=10000, + ) + for event_id in missing_auth - failed_to_fetch + ], + consumeErrors=True + ).addErrback(unwrapFirstError) + auth_events.update({a.event_id: a for a in results}) + required_auth.update( + a_id for event in results for a_id, _ in event.auth_events + ) + missing_auth = required_auth - set(auth_events) + + failed_to_fetch = missing_auth - set(auth_events) + + seen_events = yield self.store.have_events( + set(auth_events.keys()) | set(state_events.keys()) + ) ev_infos = [] for a in auth_events.values(): @@ -372,6 +399,7 @@ class FederationHandler(BaseHandler): (auth_events[a_id].type, auth_events[a_id].state_key): auth_events[a_id] for a_id, _ in a.auth_events + if a_id in auth_events } }) @@ -383,6 +411,7 @@ class FederationHandler(BaseHandler): (auth_events[a_id].type, auth_events[a_id].state_key): auth_events[a_id] for a_id, _ in event_map[e_id].auth_events + if a_id in auth_events } }) @@ -637,7 +666,7 @@ class FederationHandler(BaseHandler): pass event_stream_id, max_stream_id = yield self._persist_auth_tree( - auth_chain, state, event + origin, auth_chain, state, event ) with PreserveLoggingContext(): @@ -688,7 +717,9 @@ class FederationHandler(BaseHandler): logger.warn("Failed to create join %r because %s", event, e) raise e - self.auth.check(event, auth_events=context.current_state) + # The remote hasn't signed it yet, obviously. We'll do the full checks + # when we get the event back in `on_send_join_request` + self.auth.check(event, auth_events=context.current_state, do_sig_check=False) defer.returnValue(event) @@ -918,7 +949,9 @@ class FederationHandler(BaseHandler): ) try: - self.auth.check(event, auth_events=context.current_state) + # The remote hasn't signed it yet, obviously. We'll do the full checks + # when we get the event back in `on_send_leave_request` + self.auth.check(event, auth_events=context.current_state, do_sig_check=False) except AuthError as e: logger.warn("Failed to create new leave %r because %s", event, e) raise e @@ -987,14 +1020,9 @@ class FederationHandler(BaseHandler): defer.returnValue(None) @defer.inlineCallbacks - def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True): + def get_state_for_pdu(self, room_id, event_id): yield run_on_reactor() - if do_auth: - in_room = yield self.auth.check_host_in_room(room_id, origin) - if not in_room: - raise AuthError(403, "Host not in room.") - state_groups = yield self.store.get_state_groups( room_id, [event_id] ) @@ -1114,11 +1142,12 @@ class FederationHandler(BaseHandler): backfilled=backfilled, ) - # this intentionally does not yield: we don't care about the result - # and don't need to wait for it. - preserve_fn(self.hs.get_pusherpool().on_new_notifications)( - event_stream_id, max_stream_id - ) + if not backfilled: + # this intentionally does not yield: we don't care about the result + # and don't need to wait for it. + preserve_fn(self.hs.get_pusherpool().on_new_notifications)( + event_stream_id, max_stream_id + ) defer.returnValue((context, event_stream_id, max_stream_id)) @@ -1150,11 +1179,19 @@ class FederationHandler(BaseHandler): ) @defer.inlineCallbacks - def _persist_auth_tree(self, auth_events, state, event): + def _persist_auth_tree(self, origin, auth_events, state, event): """Checks the auth chain is valid (and passes auth checks) for the state and event. Then persists the auth chain and state atomically. Persists the event seperately. + Will attempt to fetch missing auth events. + + Args: + origin (str): Where the events came from + auth_events (list) + state (list) + event (Event) + Returns: 2-tuple of (event_stream_id, max_stream_id) from the persist_event call for `event` @@ -1167,7 +1204,7 @@ class FederationHandler(BaseHandler): event_map = { e.event_id: e - for e in auth_events + for e in itertools.chain(auth_events, state, [event]) } create_event = None @@ -1176,10 +1213,29 @@ class FederationHandler(BaseHandler): create_event = e break + missing_auth_events = set() + for e in itertools.chain(auth_events, state, [event]): + for e_id, _ in e.auth_events: + if e_id not in event_map: + missing_auth_events.add(e_id) + + for e_id in missing_auth_events: + m_ev = yield self.replication_layer.get_pdu( + [origin], + e_id, + outlier=True, + timeout=10000, + ) + if m_ev and m_ev.event_id == e_id: + event_map[e_id] = m_ev + else: + logger.info("Failed to find auth event %r", e_id) + for e in itertools.chain(auth_events, state, [event]): auth_for_e = { (event_map[e_id].type, event_map[e_id].state_key): event_map[e_id] for e_id, _ in e.auth_events + if e_id in event_map } if create_event: auth_for_e[(EventTypes.Create, "")] = create_event @@ -1413,7 +1469,7 @@ class FederationHandler(BaseHandler): local_view = dict(auth_events) remote_view = dict(auth_events) remote_view.update({ - (d.type, d.state_key): d for d in different_events + (d.type, d.state_key): d for d in different_events if d }) new_state, prev_state = self.state_handler.resolve_events( |