summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/federation.py123
1 files changed, 54 insertions, 69 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index e8330a2b50..798ed75b30 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -552,8 +552,12 @@ class FederationHandler(BaseHandler):
         destination: str,
         room_id: str,
         event_id: str,
-    ) -> Tuple[List[EventBase], List[EventBase]]:
-        """Requests all of the room state at a given event from a remote homeserver.
+    ) -> List[EventBase]:
+        """Requests all of the room state at a given event from a remote
+        homeserver.
+
+        Will also fetch any missing events reported in the `auth_chain_ids`
+        section of `/state_ids`.
 
         Args:
             destination: The remote homeserver to query for the state.
@@ -561,8 +565,7 @@ class FederationHandler(BaseHandler):
             event_id: The id of the event we want the state at.
 
         Returns:
-            A list of events in the state, not including the event itself, and
-            a list of events in the auth chain for the given event.
+            A list of events in the state, not including the event itself.
         """
         (
             state_event_ids,
@@ -571,68 +574,53 @@ class FederationHandler(BaseHandler):
             destination, room_id, event_id=event_id
         )
 
-        desired_events = set(state_event_ids + auth_event_ids)
-
-        event_map = await self._get_events_from_store_or_dest(
-            destination, room_id, desired_events
-        )
+        # Fetch the state events from the DB, and check we have the auth events.
+        event_map = await self.store.get_events(state_event_ids, allow_rejected=True)
+        auth_events_in_store = await self.store.have_seen_events(auth_event_ids)
 
-        failed_to_fetch = desired_events - event_map.keys()
-        if failed_to_fetch:
-            logger.warning(
-                "Failed to fetch missing state/auth events for %s %s",
-                event_id,
-                failed_to_fetch,
+        # Check for missing events. We handle state and auth event seperately,
+        # as we want to pull the state from the DB, but we don't for the auth
+        # events. (Note: we likely won't use the majority of the auth chain, and
+        # it can be *huge* for large rooms, so it's worth ensuring that we don't
+        # unnecessarily pull it from the DB).
+        missing_state_events = set(state_event_ids) - set(event_map)
+        missing_auth_events = set(auth_event_ids) - set(auth_events_in_store)
+        if missing_state_events or missing_auth_events:
+            await self._get_events_and_persist(
+                destination=destination,
+                room_id=room_id,
+                events=missing_state_events | missing_auth_events,
             )
 
-        remote_state = [
-            event_map[e_id] for e_id in state_event_ids if e_id in event_map
-        ]
-
-        auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
-        auth_chain.sort(key=lambda e: e.depth)
-
-        return remote_state, auth_chain
-
-    async def _get_events_from_store_or_dest(
-        self, destination: str, room_id: str, event_ids: Iterable[str]
-    ) -> Dict[str, EventBase]:
-        """Fetch events from a remote destination, checking if we already have them.
-
-        Persists any events we don't already have as outliers.
-
-        If we fail to fetch any of the events, a warning will be logged, and the event
-        will be omitted from the result. Likewise, any events which turn out not to
-        be in the given room.
+            if missing_state_events:
+                new_events = await self.store.get_events(
+                    missing_state_events, allow_rejected=True
+                )
+                event_map.update(new_events)
 
-        This function *does not* automatically get missing auth events of the
-        newly fetched events. Callers must include the full auth chain of
-        of the missing events in the `event_ids` argument, to ensure that any
-        missing auth events are correctly fetched.
+                missing_state_events.difference_update(new_events)
 
-        Returns:
-            map from event_id to event
-        """
-        fetched_events = await self.store.get_events(event_ids, allow_rejected=True)
-
-        missing_events = set(event_ids) - fetched_events.keys()
+                if missing_state_events:
+                    logger.warning(
+                        "Failed to fetch missing state events for %s %s",
+                        event_id,
+                        missing_state_events,
+                    )
 
-        if missing_events:
-            logger.debug(
-                "Fetching unknown state/auth events %s for room %s",
-                missing_events,
-                room_id,
-            )
+            if missing_auth_events:
+                auth_events_in_store = await self.store.have_seen_events(
+                    missing_auth_events
+                )
+                missing_auth_events.difference_update(auth_events_in_store)
 
-            await self._get_events_and_persist(
-                destination=destination, room_id=room_id, events=missing_events
-            )
+                if missing_auth_events:
+                    logger.warning(
+                        "Failed to fetch missing auth events for %s %s",
+                        event_id,
+                        missing_auth_events,
+                    )
 
-            # we need to make sure we re-load from the database to get the rejected
-            # state correct.
-            fetched_events.update(
-                (await self.store.get_events(missing_events, allow_rejected=True))
-            )
+        remote_state = list(event_map.values())
 
         # check for events which were in the wrong room.
         #
@@ -640,8 +628,8 @@ class FederationHandler(BaseHandler):
         # auth_events at an event in room A are actually events in room B
 
         bad_events = [
-            (event_id, event.room_id)
-            for event_id, event in fetched_events.items()
+            (event.event_id, event.room_id)
+            for event in remote_state
             if event.room_id != room_id
         ]
 
@@ -658,9 +646,10 @@ class FederationHandler(BaseHandler):
                 room_id,
             )
 
-            del fetched_events[bad_event_id]
+        if bad_events:
+            remote_state = [e for e in remote_state if e.room_id == room_id]
 
-        return fetched_events
+        return remote_state
 
     async def _get_state_after_missing_prev_event(
         self,
@@ -963,27 +952,23 @@ class FederationHandler(BaseHandler):
 
         # For each edge get the current state.
 
-        auth_events = {}
         state_events = {}
         events_to_state = {}
         for e_id in edges:
-            state, auth = await self._get_state_for_room(
+            state = await self._get_state_for_room(
                 destination=dest,
                 room_id=room_id,
                 event_id=e_id,
             )
-            auth_events.update({a.event_id: a for a in auth})
-            auth_events.update({s.event_id: s for s in state})
             state_events.update({s.event_id: s for s in state})
             events_to_state[e_id] = state
 
         required_auth = {
             a_id
-            for event in events
-            + list(state_events.values())
-            + list(auth_events.values())
+            for event in events + list(state_events.values())
             for a_id in event.auth_event_ids()
         }
+        auth_events = await self.store.get_events(required_auth, allow_rejected=True)
         auth_events.update(
             {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map}
         )