diff options
Diffstat (limited to 'synapse/handlers/federation.py')
-rw-r--r-- | synapse/handlers/federation.py | 342 |
1 files changed, 283 insertions, 59 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 85e2757227..46ce3699d7 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -18,9 +18,11 @@ from ._base import BaseHandler from synapse.api.errors import ( - AuthError, FederationError, StoreError, + AuthError, FederationError, StoreError, CodeMessageException, SynapseError, ) from synapse.api.constants import EventTypes, Membership, RejectedReason +from synapse.util import unwrapFirstError +from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logutils import log_function from synapse.util.async import run_on_reactor from synapse.util.frozenutils import unfreeze @@ -29,6 +31,8 @@ from synapse.crypto.event_signing import ( ) from synapse.types import UserID +from synapse.util.retryutils import NotRetryingDestination + from twisted.internet import defer import itertools @@ -156,7 +160,7 @@ class FederationHandler(BaseHandler): ) try: - yield self._handle_new_event( + _, event_stream_id, max_stream_id = yield self._handle_new_event( origin, event, state=state, @@ -197,9 +201,11 @@ class FederationHandler(BaseHandler): target_user = UserID.from_string(target_user_id) extra_users.append(target_user) - d = self.notifier.on_new_room_event( - event, extra_users=extra_users - ) + with PreserveLoggingContext(): + d = self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, + extra_users=extra_users + ) def log_failure(f): logger.warn( @@ -218,37 +224,210 @@ class FederationHandler(BaseHandler): @log_function @defer.inlineCallbacks - def backfill(self, dest, room_id, limit): + def backfill(self, dest, room_id, limit, extremities=[]): """ Trigger a backfill request to `dest` for the given `room_id` """ - extremities = yield self.store.get_oldest_events_in_room(room_id) + if not extremities: + extremities = yield self.store.get_oldest_events_in_room(room_id) - pdus = yield self.replication_layer.backfill( + events = yield self.replication_layer.backfill( dest, room_id, - limit, + limit=limit, extremities=extremities, ) - events = [] + event_map = {e.event_id: e for e in events} - for pdu in pdus: - event = pdu + event_ids = set(e.event_id for e in events) - # FIXME (erikj): Not sure this actually works :/ - context = yield self.state_handler.compute_event_context(event) + edges = [ + ev.event_id + for ev in events + if set(e_id for e_id, _ in ev.prev_events) - event_ids + ] - events.append((event, context)) + # For each edge get the current state. - yield self.store.persist_event( - event, - context=context, - backfilled=True + auth_events = {} + events_to_state = {} + for e_id in edges: + state, auth = yield self.replication_layer.get_state_for_room( + destination=dest, + room_id=room_id, + event_id=e_id + ) + auth_events.update({a.event_id: a for a in auth}) + events_to_state[e_id] = state + + yield defer.gatherResults( + [ + self._handle_new_event(dest, a) + for a in auth_events.values() + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + + yield defer.gatherResults( + [ + self._handle_new_event( + dest, event_map[e_id], + state=events_to_state[e_id], + backfilled=True, + ) + for e_id in events_to_state + ], + consumeErrors=True + ).addErrback(unwrapFirstError) + + events.sort(key=lambda e: e.depth) + + for event in events: + if event in events_to_state: + continue + + yield self._handle_new_event( + dest, event, + backfilled=True, ) defer.returnValue(events) @defer.inlineCallbacks + def maybe_backfill(self, room_id, current_depth): + """Checks the database to see if we should backfill before paginating, + and if so do. + """ + extremities = yield self.store.get_oldest_events_with_depth_in_room( + room_id + ) + + if not extremities: + logger.debug("Not backfilling as no extremeties found.") + return + + # Check if we reached a point where we should start backfilling. + sorted_extremeties_tuple = sorted( + extremities.items(), + key=lambda e: -int(e[1]) + ) + max_depth = sorted_extremeties_tuple[0][1] + + if current_depth > max_depth: + logger.debug( + "Not backfilling as we don't need to. %d < %d", + max_depth, current_depth, + ) + return + + # Now we need to decide which hosts to hit first. + + # First we try hosts that are already in the room + # TODO: HEURISTIC ALERT. + + curr_state = yield self.state_handler.get_current_state(room_id) + + def get_domains_from_state(state): + joined_users = [ + (state_key, int(event.depth)) + for (e_type, state_key), event in state.items() + if e_type == EventTypes.Member + and event.membership == Membership.JOIN + ] + + joined_domains = {} + for u, d in joined_users: + try: + dom = UserID.from_string(u).domain + old_d = joined_domains.get(dom) + if old_d: + joined_domains[dom] = min(d, old_d) + else: + joined_domains[dom] = d + except: + pass + + return sorted(joined_domains.items(), key=lambda d: d[1]) + + curr_domains = get_domains_from_state(curr_state) + + likely_domains = [ + domain for domain, depth in curr_domains + if domain is not self.server_name + ] + + @defer.inlineCallbacks + def try_backfill(domains): + # TODO: Should we try multiple of these at a time? + for dom in domains: + try: + events = yield self.backfill( + dom, room_id, + limit=100, + extremities=[e for e in extremities.keys()] + ) + except SynapseError: + logger.info( + "Failed to backfill from %s because %s", + dom, e, + ) + continue + except CodeMessageException as e: + if 400 <= e.code < 500: + raise + + logger.info( + "Failed to backfill from %s because %s", + dom, e, + ) + continue + except NotRetryingDestination as e: + logger.info(e.message) + continue + except Exception as e: + logger.exception( + "Failed to backfill from %s because %s", + dom, e, + ) + continue + + if events: + defer.returnValue(True) + defer.returnValue(False) + + success = yield try_backfill(likely_domains) + if success: + defer.returnValue(True) + + # Huh, well *those* domains didn't work out. Lets try some domains + # from the time. + + tried_domains = set(likely_domains) + tried_domains.add(self.server_name) + + event_ids = list(extremities.keys()) + + states = yield defer.gatherResults([ + self.state_handler.resolve_state_groups([e]) + for e in event_ids + ]) + states = dict(zip(event_ids, [s[1] for s in states])) + + for e_id, _ in sorted_extremeties_tuple: + likely_domains = get_domains_from_state(states[e_id]) + + success = yield try_backfill([ + dom for dom in likely_domains + if dom not in tried_domains + ]) + if success: + defer.returnValue(True) + + tried_domains.update(likely_domains) + + defer.returnValue(False) + + @defer.inlineCallbacks def send_invite(self, target_host, event): """ Sends the invite to the remote server for signing. @@ -376,30 +555,14 @@ class FederationHandler(BaseHandler): # FIXME pass - for e in auth_chain: - e.internal_metadata.outlier = True - - if e.event_id == event.event_id: - continue - - try: - auth_ids = [e_id for e_id, _ in e.auth_events] - auth = { - (e.type, e.state_key): e for e in auth_chain - if e.event_id in auth_ids - } - yield self._handle_new_event( - origin, e, auth_events=auth - ) - except: - logger.exception( - "Failed to handle auth event %s", - e.event_id, - ) + yield self._handle_auth_events( + origin, [e for e in auth_chain if e.event_id != event.event_id] + ) - for e in state: + @defer.inlineCallbacks + def handle_state(e): if e.event_id == event.event_id: - continue + return e.internal_metadata.outlier = True try: @@ -417,13 +580,15 @@ class FederationHandler(BaseHandler): e.event_id, ) + yield defer.DeferredList([handle_state(e) for e in state]) + auth_ids = [e_id for e_id, _ in event.auth_events] auth_events = { (e.type, e.state_key): e for e in auth_chain if e.event_id in auth_ids } - yield self._handle_new_event( + _, event_stream_id, max_stream_id = yield self._handle_new_event( origin, new_event, state=state, @@ -431,9 +596,11 @@ class FederationHandler(BaseHandler): auth_events=auth_events, ) - d = self.notifier.on_new_room_event( - new_event, extra_users=[joinee] - ) + with PreserveLoggingContext(): + d = self.notifier.on_new_room_event( + new_event, event_stream_id, max_stream_id, + extra_users=[joinee] + ) def log_failure(f): logger.warn( @@ -498,7 +665,9 @@ class FederationHandler(BaseHandler): event.internal_metadata.outlier = False - context = yield self._handle_new_event(origin, event) + context, event_stream_id, max_stream_id = yield self._handle_new_event( + origin, event + ) logger.debug( "on_send_join_request: After _handle_new_event: %s, sigs: %s", @@ -512,9 +681,10 @@ class FederationHandler(BaseHandler): target_user = UserID.from_string(target_user_id) extra_users.append(target_user) - d = self.notifier.on_new_room_event( - event, extra_users=extra_users - ) + with PreserveLoggingContext(): + d = self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, extra_users=extra_users + ) def log_failure(f): logger.warn( @@ -587,16 +757,18 @@ class FederationHandler(BaseHandler): context = yield self.state_handler.compute_event_context(event) - yield self.store.persist_event( + event_stream_id, max_stream_id = yield self.store.persist_event( event, context=context, backfilled=False, ) target_user = UserID.from_string(event.state_key) - d = self.notifier.on_new_room_event( - event, extra_users=[target_user], - ) + with PreserveLoggingContext(): + d = self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, + extra_users=[target_user], + ) def log_failure(f): logger.warn( @@ -745,9 +917,12 @@ class FederationHandler(BaseHandler): # This is a hack to fix some old rooms where the initial join event # didn't reference the create event in its auth events. if event.type == EventTypes.Member and not event.auth_events: - if len(event.prev_events) == 1: - c = yield self.store.get_event(event.prev_events[0][0]) - if c.type == EventTypes.Create: + if len(event.prev_events) == 1 and event.depth < 5: + c = yield self.store.get_event( + event.prev_events[0][0], + allow_none=True, + ) + if c and c.type == EventTypes.Create: auth_events[(c.type, c.state_key)] = c try: @@ -773,7 +948,7 @@ class FederationHandler(BaseHandler): ) raise - yield self.store.persist_event( + event_stream_id, max_stream_id = yield self.store.persist_event( event, context=context, backfilled=backfilled, @@ -781,7 +956,7 @@ class FederationHandler(BaseHandler): current_state=current_state, ) - defer.returnValue(context) + defer.returnValue((context, event_stream_id, max_stream_id)) @defer.inlineCallbacks def on_query_auth(self, origin, event_id, remote_auth_chain, rejects, @@ -921,7 +1096,7 @@ class FederationHandler(BaseHandler): if d in have_events and not have_events[d] ], consumeErrors=True - ) + ).addErrback(unwrapFirstError) if different_events: local_view = dict(auth_events) @@ -1166,3 +1341,52 @@ class FederationHandler(BaseHandler): }, "missing": [e.event_id for e in missing_locals], }) + + @defer.inlineCallbacks + def _handle_auth_events(self, origin, auth_events): + auth_ids_to_deferred = {} + + def process_auth_ev(ev): + auth_ids = [e_id for e_id, _ in ev.auth_events] + + prev_ds = [ + auth_ids_to_deferred[i] + for i in auth_ids + if i in auth_ids_to_deferred + ] + + d = defer.Deferred() + + auth_ids_to_deferred[ev.event_id] = d + + @defer.inlineCallbacks + def f(*_): + ev.internal_metadata.outlier = True + + try: + auth = { + (e.type, e.state_key): e for e in auth_events + if e.event_id in auth_ids + } + + yield self._handle_new_event( + origin, ev, auth_events=auth + ) + except: + logger.exception( + "Failed to handle auth event %s", + ev.event_id, + ) + + d.callback(None) + + if prev_ds: + dx = defer.DeferredList(prev_ds) + dx.addBoth(f) + else: + f() + + for e in auth_events: + process_auth_ev(e) + + yield defer.DeferredList(auth_ids_to_deferred.values()) |