summary refs log tree commit diff
path: root/synapse/handlers/federation_event.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/federation_event.py')
-rw-r--r--synapse/handlers/federation_event.py51
1 files changed, 46 insertions, 5 deletions
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 4bd87709f3..03c1197c99 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -224,7 +224,7 @@ class FederationEventHandler:
                     len(missing_prevs),
                     shortstr(missing_prevs),
                 )
-                with (await self._room_pdu_linearizer.queue(pdu.room_id)):
+                async with self._room_pdu_linearizer.queue(pdu.room_id):
                     logger.info(
                         "Acquired room lock to fetch %d missing prev_events",
                         len(missing_prevs),
@@ -469,6 +469,12 @@ class FederationEventHandler:
             if context.rejected:
                 raise SynapseError(400, "Join event was rejected")
 
+            # the remote server is responsible for sending our join event to the rest
+            # of the federation. Indeed, attempting to do so will result in problems
+            # when we try to look up the state before the join (to get the server list)
+            # and discover that we do not have it.
+            event.internal_metadata.proactively_send = False
+
             return await self.persist_events_and_notify(room_id, [(event, context)])
 
     async def backfill(
@@ -891,10 +897,24 @@ class FederationEventHandler:
         logger.debug("We are also missing %i auth events", len(missing_auth_events))
 
         missing_events = missing_desired_events | missing_auth_events
-        logger.debug("Fetching %i events from remote", len(missing_events))
-        await self._get_events_and_persist(
-            destination=destination, room_id=room_id, event_ids=missing_events
-        )
+
+        # Making an individual request for each of 1000s of events has a lot of
+        # overhead. On the other hand, we don't really want to fetch all of the events
+        # if we already have most of them.
+        #
+        # As an arbitrary heuristic, if we are missing more than 10% of the events, then
+        # we fetch the whole state.
+        #
+        # TODO: might it be better to have an API which lets us do an aggregate event
+        #   request
+        if (len(missing_events) * 10) >= len(auth_event_ids) + len(state_event_ids):
+            logger.debug("Requesting complete state from remote")
+            await self._get_state_and_persist(destination, room_id, event_id)
+        else:
+            logger.debug("Fetching %i events from remote", len(missing_events))
+            await self._get_events_and_persist(
+                destination=destination, room_id=room_id, event_ids=missing_events
+            )
 
         # we need to make sure we re-load from the database to get the rejected
         # state correct.
@@ -953,6 +973,27 @@ class FederationEventHandler:
 
         return remote_state
 
+    async def _get_state_and_persist(
+        self, destination: str, room_id: str, event_id: str
+    ) -> None:
+        """Get the complete room state at a given event, and persist any new events
+        as outliers"""
+        room_version = await self._store.get_room_version(room_id)
+        auth_events, state_events = await self._federation_client.get_room_state(
+            destination, room_id, event_id=event_id, room_version=room_version
+        )
+        logger.info("/state returned %i events", len(auth_events) + len(state_events))
+
+        await self._auth_and_persist_outliers(
+            room_id, itertools.chain(auth_events, state_events)
+        )
+
+        # we also need the event itself.
+        if not await self._store.have_seen_event(room_id, event_id):
+            await self._get_events_and_persist(
+                destination=destination, room_id=room_id, event_ids=(event_id,)
+            )
+
     async def _process_received_pdu(
         self,
         origin: str,