summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/9601.feature1
-rw-r--r--synapse/handlers/federation.py152
-rw-r--r--synapse/storage/databases/main/events_worker.py12
3 files changed, 137 insertions, 28 deletions
diff --git a/changelog.d/9601.feature b/changelog.d/9601.feature
new file mode 100644
index 0000000000..5078d63ffa
--- /dev/null
+++ b/changelog.d/9601.feature
@@ -0,0 +1 @@
+Optimise handling of incomplete room history for incoming federation.
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 1d20c441f3..598a66f74c 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -353,17 +353,16 @@ class FederationHandler(BaseHandler):
                     # Ask the remote server for the states we don't
                     # know about
                     for p in prevs - seen:
-                        logger.info(
-                            "Requesting state at missing prev_event %s",
-                            event_id,
-                        )
+                        logger.info("Requesting state after missing prev_event %s", p)
 
                         with nested_logging_context(p):
                             # note that if any of the missing prevs share missing state or
                             # auth events, the requests to fetch those events are deduped
                             # by the get_pdu_cache in federation_client.
-                            (remote_state, _,) = await self._get_state_for_room(
-                                origin, room_id, p, include_event_in_state=True
+                            remote_state = (
+                                await self._get_state_after_missing_prev_event(
+                                    origin, room_id, p
+                                )
                             )
 
                             remote_state_map = {
@@ -539,7 +538,6 @@ class FederationHandler(BaseHandler):
         destination: str,
         room_id: str,
         event_id: str,
-        include_event_in_state: bool = False,
     ) -> Tuple[List[EventBase], List[EventBase]]:
         """Requests all of the room state at a given event from a remote homeserver.
 
@@ -547,11 +545,9 @@ class FederationHandler(BaseHandler):
             destination: The remote homeserver to query for the state.
             room_id: The id of the room we're interested in.
             event_id: The id of the event we want the state at.
-            include_event_in_state: if true, the event itself will be included in the
-                returned state event list.
 
         Returns:
-            A list of events in the state, possibly including the event itself, and
+            A list of events in the state, not including the event itself, and
             a list of events in the auth chain for the given event.
         """
         (
@@ -563,9 +559,6 @@ class FederationHandler(BaseHandler):
 
         desired_events = set(state_event_ids + auth_event_ids)
 
-        if include_event_in_state:
-            desired_events.add(event_id)
-
         event_map = await self._get_events_from_store_or_dest(
             destination, room_id, desired_events
         )
@@ -582,13 +575,6 @@ class FederationHandler(BaseHandler):
             event_map[e_id] for e_id in state_event_ids if e_id in event_map
         ]
 
-        if include_event_in_state:
-            remote_event = event_map.get(event_id)
-            if not remote_event:
-                raise Exception("Unable to get missing prev_event %s" % (event_id,))
-            if remote_event.is_state() and remote_event.rejected_reason is None:
-                remote_state.append(remote_event)
-
         auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
         auth_chain.sort(key=lambda e: e.depth)
 
@@ -662,6 +648,131 @@ class FederationHandler(BaseHandler):
 
         return fetched_events
 
+    async def _get_state_after_missing_prev_event(
+        self,
+        destination: str,
+        room_id: str,
+        event_id: str,
+    ) -> List[EventBase]:
+        """Requests all of the room state at a given event from a remote homeserver.
+
+        Args:
+            destination: The remote homeserver to query for the state.
+            room_id: The id of the room we're interested in.
+            event_id: The id of the event we want the state at.
+
+        Returns:
+            A list of events in the state, including the event itself
+        """
+        # TODO: This function is basically the same as _get_state_for_room. Can
+        #   we make backfill() use it, rather than having two code paths? I think the
+        #   only difference is that backfill() persists the prev events separately.
+
+        (
+            state_event_ids,
+            auth_event_ids,
+        ) = await self.federation_client.get_room_state_ids(
+            destination, room_id, event_id=event_id
+        )
+
+        logger.debug(
+            "state_ids returned %i state events, %i auth events",
+            len(state_event_ids),
+            len(auth_event_ids),
+        )
+
+        # start by just trying to fetch the events from the store
+        desired_events = set(state_event_ids)
+        desired_events.add(event_id)
+        logger.debug("Fetching %i events from cache/store", len(desired_events))
+        fetched_events = await self.store.get_events(
+            desired_events, allow_rejected=True
+        )
+
+        missing_desired_events = desired_events - fetched_events.keys()
+        logger.debug(
+            "We are missing %i events (got %i)",
+            len(missing_desired_events),
+            len(fetched_events),
+        )
+
+        # We probably won't need most of the auth events, so let's just check which
+        # we have for now, rather than thrashing the event cache with them all
+        # unnecessarily.
+
+        # TODO: we probably won't actually need all of the auth events, since we
+        #   already have a bunch of the state events. It would be nice if the
+        #   federation api gave us a way of finding out which we actually need.
+
+        missing_auth_events = set(auth_event_ids) - fetched_events.keys()
+        missing_auth_events.difference_update(
+            await self.store.have_seen_events(missing_auth_events)
+        )
+        logger.debug("We are also missing %i auth events", len(missing_auth_events))
+
+        missing_events = missing_desired_events | missing_auth_events
+        logger.debug("Fetching %i events from remote", len(missing_events))
+        await self._get_events_and_persist(
+            destination=destination, room_id=room_id, events=missing_events
+        )
+
+        # we need to make sure we re-load from the database to get the rejected
+        # state correct.
+        fetched_events.update(
+            (await self.store.get_events(missing_desired_events, allow_rejected=True))
+        )
+
+        # check for events which were in the wrong room.
+        #
+        # this can happen if a remote server claims that the state or
+        # auth_events at an event in room A are actually events in room B
+
+        bad_events = [
+            (event_id, event.room_id)
+            for event_id, event in fetched_events.items()
+            if event.room_id != room_id
+        ]
+
+        for bad_event_id, bad_room_id in bad_events:
+            # This is a bogus situation, but since we may only discover it a long time
+            # after it happened, we try our best to carry on, by just omitting the
+            # bad events from the returned state set.
+            logger.warning(
+                "Remote server %s claims event %s in room %s is an auth/state "
+                "event in room %s",
+                destination,
+                bad_event_id,
+                bad_room_id,
+                room_id,
+            )
+
+            del fetched_events[bad_event_id]
+
+        # if we couldn't get the prev event in question, that's a problem.
+        remote_event = fetched_events.get(event_id)
+        if not remote_event:
+            raise Exception("Unable to get missing prev_event %s" % (event_id,))
+
+        # missing state at that event is a warning, not a blocker
+        # XXX: this doesn't sound right? it means that we'll end up with incomplete
+        #   state.
+        failed_to_fetch = desired_events - fetched_events.keys()
+        if failed_to_fetch:
+            logger.warning(
+                "Failed to fetch missing state events for %s %s",
+                event_id,
+                failed_to_fetch,
+            )
+
+        remote_state = [
+            fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events
+        ]
+
+        if remote_event.is_state() and remote_event.rejected_reason is None:
+            remote_state.append(remote_event)
+
+        return remote_state
+
     async def _process_received_pdu(
         self,
         origin: str,
@@ -841,7 +952,6 @@ class FederationHandler(BaseHandler):
                 destination=dest,
                 room_id=room_id,
                 event_id=e_id,
-                include_event_in_state=False,
             )
             auth_events.update({a.event_id: a for a in auth})
             auth_events.update({s.event_id: s for s in state})
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index edbe42f2bf..c04e162ccc 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import itertools
+
 import logging
 import threading
 from collections import namedtuple
@@ -1044,7 +1044,8 @@ class EventsWorkerStore(SQLBaseStore):
         Returns:
             set[str]: The events we have already seen.
         """
-        results = set()
+        # if the event cache contains the event, obviously we've seen it.
+        results = {x for x in event_ids if self._get_event_cache.contains(x)}
 
         def have_seen_events_txn(txn, chunk):
             sql = "SELECT event_id FROM events as e WHERE "
@@ -1052,12 +1053,9 @@ class EventsWorkerStore(SQLBaseStore):
                 txn.database_engine, "e.event_id", chunk
             )
             txn.execute(sql + clause, args)
-            for (event_id,) in txn:
-                results.add(event_id)
+            results.update(row[0] for row in txn)
 
-        # break the input up into chunks of 100
-        input_iterator = iter(event_ids)
-        for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
+        for chunk in batch_iter((x for x in event_ids if x not in results), 100):
             await self.db_pool.runInteraction(
                 "have_seen_events", have_seen_events_txn, chunk
             )