summary refs log tree commit diff
path: root/synapse/handlers/federation.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/federation.py')
-rw-r--r--synapse/handlers/federation.py151
1 files changed, 105 insertions, 46 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 2ccdc3bfa7..45d955e6f5 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -18,7 +18,6 @@
 
 import itertools
 import logging
-import sys
 
 import six
 from six import iteritems, itervalues
@@ -106,7 +105,7 @@ class FederationHandler(BaseHandler):
 
         self.hs = hs
 
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastore()  # type: synapse.storage.DataStore
         self.federation_client = hs.get_federation_client()
         self.state_handler = hs.get_state_handler()
         self.server_name = hs.hostname
@@ -323,14 +322,22 @@ class FederationHandler(BaseHandler):
                         affected=pdu.event_id,
                     )
 
-                # Calculate the state of the previous events, and
-                # de-conflict them to find the current state.
-                state_groups = []
+                # Calculate the state after each of the previous events, and
+                # resolve them to find the correct state at the current event.
                 auth_chains = set()
+                event_map = {
+                    event_id: pdu,
+                }
                 try:
                     # Get the state of the events we know about
-                    ours = yield self.store.get_state_groups(room_id, list(seen))
-                    state_groups.append(ours)
+                    ours = yield self.store.get_state_groups_ids(room_id, seen)
+
+                    # state_maps is a list of mappings from (type, state_key) to event_id
+                    # type: list[dict[tuple[str, str], str]]
+                    state_maps = list(ours.values())
+
+                    # we don't need this any more, let's delete it.
+                    del ours
 
                     # Ask the remote server for the states we don't
                     # know about
@@ -339,27 +346,65 @@ class FederationHandler(BaseHandler):
                             "[%s %s] Requesting state at missing prev_event %s",
                             room_id, event_id, p,
                         )
-                        state, got_auth_chain = (
-                            yield self.federation_client.get_state_for_room(
-                                origin, room_id, p,
+
+                        with logcontext.nested_logging_context(p):
+                            # note that if any of the missing prevs share missing state or
+                            # auth events, the requests to fetch those events are deduped
+                            # by the get_pdu_cache in federation_client.
+                            remote_state, got_auth_chain = (
+                                yield self.federation_client.get_state_for_room(
+                                    origin, room_id, p,
+                                )
                             )
-                        )
-                        auth_chains.update(got_auth_chain)
-                        state_group = {(x.type, x.state_key): x.event_id for x in state}
-                        state_groups.append(state_group)
+
+                            # we want the state *after* p; get_state_for_room returns the
+                            # state *before* p.
+                            remote_event = yield self.federation_client.get_pdu(
+                                [origin], p, outlier=True,
+                            )
+
+                            if remote_event is None:
+                                raise Exception(
+                                    "Unable to get missing prev_event %s" % (p, )
+                                )
+
+                            if remote_event.is_state():
+                                remote_state.append(remote_event)
+
+                            # XXX hrm I'm not convinced that duplicate events will compare
+                            # for equality, so I'm not sure this does what the author
+                            # hoped.
+                            auth_chains.update(got_auth_chain)
+
+                            remote_state_map = {
+                                (x.type, x.state_key): x.event_id for x in remote_state
+                            }
+                            state_maps.append(remote_state_map)
+
+                            for x in remote_state:
+                                event_map[x.event_id] = x
 
                     # Resolve any conflicting state
+                    @defer.inlineCallbacks
                     def fetch(ev_ids):
-                        return self.store.get_events(
-                            ev_ids, get_prev_content=False, check_redacted=False
+                        fetched = yield self.store.get_events(
+                            ev_ids, get_prev_content=False, check_redacted=False,
                         )
+                        # add any events we fetch here to the `event_map` so that we
+                        # can use them to build the state event list below.
+                        event_map.update(fetched)
+                        defer.returnValue(fetched)
 
                     room_version = yield self.store.get_room_version(room_id)
                     state_map = yield resolve_events_with_factory(
-                        room_version, state_groups, {event_id: pdu}, fetch
+                        room_version, state_maps, event_map, fetch,
                     )
 
-                    state = (yield self.store.get_events(state_map.values())).values()
+                    # we need to give _process_received_pdu the actual state events
+                    # rather than event ids, so generate that now.
+                    state = [
+                        event_map[e] for e in six.itervalues(state_map)
+                    ]
                     auth_chain = list(auth_chains)
                 except Exception:
                     logger.warn(
@@ -483,20 +528,21 @@ class FederationHandler(BaseHandler):
                 "[%s %s] Handling received prev_event %s",
                 room_id, event_id, ev.event_id,
             )
-            try:
-                yield self.on_receive_pdu(
-                    origin,
-                    ev,
-                    sent_to_us_directly=False,
-                )
-            except FederationError as e:
-                if e.code == 403:
-                    logger.warn(
-                        "[%s %s] Received prev_event %s failed history check.",
-                        room_id, event_id, ev.event_id,
+            with logcontext.nested_logging_context(ev.event_id):
+                try:
+                    yield self.on_receive_pdu(
+                        origin,
+                        ev,
+                        sent_to_us_directly=False,
                     )
-                else:
-                    raise
+                except FederationError as e:
+                    if e.code == 403:
+                        logger.warn(
+                            "[%s %s] Received prev_event %s failed history check.",
+                            room_id, event_id, ev.event_id,
+                        )
+                    else:
+                        raise
 
     @defer.inlineCallbacks
     def _process_received_pdu(self, origin, event, state, auth_chain):
@@ -572,6 +618,10 @@ class FederationHandler(BaseHandler):
                     })
                     seen_ids.add(e.event_id)
 
+                logger.info(
+                    "[%s %s] persisting newly-received auth/state events %s",
+                    room_id, event_id, [e["event"].event_id for e in event_infos]
+                )
                 yield self._handle_new_events(origin, event_infos)
 
             try:
@@ -1135,7 +1185,8 @@ class FederationHandler(BaseHandler):
             try:
                 logger.info("Processing queued PDU %s which was received "
                             "while we were joining %s", p.event_id, p.room_id)
-                yield self.on_receive_pdu(origin, p, sent_to_us_directly=True)
+                with logcontext.nested_logging_context(p.event_id):
+                    yield self.on_receive_pdu(origin, p, sent_to_us_directly=True)
             except Exception as e:
                 logger.warn(
                     "Error handling queued PDU %s from %s: %s",
@@ -1550,6 +1601,9 @@ class FederationHandler(BaseHandler):
             auth_events=auth_events,
         )
 
+        # reraise does not allow inlineCallbacks to preserve the stacktrace, so we
+        # hack around with a try/finally instead.
+        success = False
         try:
             if not event.internal_metadata.is_outlier() and not backfilled:
                 yield self.action_generator.handle_push_actions_for_event(
@@ -1560,15 +1614,13 @@ class FederationHandler(BaseHandler):
                 [(event, context)],
                 backfilled=backfilled,
             )
-        except:  # noqa: E722, as we reraise the exception this is fine.
-            tp, value, tb = sys.exc_info()
-
-            logcontext.run_in_background(
-                self.store.remove_push_actions_from_staging,
-                event.event_id,
-            )
-
-            six.reraise(tp, value, tb)
+            success = True
+        finally:
+            if not success:
+                logcontext.run_in_background(
+                    self.store.remove_push_actions_from_staging,
+                    event.event_id,
+                )
 
         defer.returnValue(context)
 
@@ -1581,15 +1633,22 @@ class FederationHandler(BaseHandler):
 
         Notifies about the events where appropriate.
         """
-        contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults(
-            [
-                logcontext.run_in_background(
-                    self._prep_event,
+
+        @defer.inlineCallbacks
+        def prep(ev_info):
+            event = ev_info["event"]
+            with logcontext.nested_logging_context(suffix=event.event_id):
+                res = yield self._prep_event(
                     origin,
-                    ev_info["event"],
+                    event,
                     state=ev_info.get("state"),
                     auth_events=ev_info.get("auth_events"),
                 )
+            defer.returnValue(res)
+
+        contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults(
+            [
+                logcontext.run_in_background(prep, ev_info)
                 for ev_info in event_infos
             ], consumeErrors=True,
         ))