summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/10303.bugfix1
-rw-r--r--synapse/federation/federation_server.py67
-rw-r--r--synapse/storage/databases/main/event_federation.py9
3 files changed, 67 insertions, 10 deletions
diff --git a/changelog.d/10303.bugfix b/changelog.d/10303.bugfix
new file mode 100644
index 0000000000..c0577c9f73
--- /dev/null
+++ b/changelog.d/10303.bugfix
@@ -0,0 +1 @@
+Ensure that inbound events from federation that were being processed when Synapse was restarted get promptly processed on start up.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index b312d0b809..bf67d0f574 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -148,6 +148,41 @@ class FederationServer(FederationBase):
 
         self._room_prejoin_state_types = hs.config.api.room_prejoin_state
 
+        # Whether we have started handling old events in the staging area.
+        self._started_handling_of_staged_events = False
+
+    @wrap_as_background_process("_handle_old_staged_events")
+    async def _handle_old_staged_events(self) -> None:
+        """Handle old staged events by fetching all rooms that have staged
+        events and start the processing of each of those rooms.
+        """
+
+        # Get all the rooms IDs with staged events.
+        room_ids = await self.store.get_all_rooms_with_staged_incoming_events()
+
+        # We then shuffle them so that if there are multiple instances doing
+        # this work they're less likely to collide.
+        random.shuffle(room_ids)
+
+        for room_id in room_ids:
+            room_version = await self.store.get_room_version(room_id)
+
+            # Try and acquire the processing lock for the room, if we get it start a
+            # background process for handling the events in the room.
+            lock = await self.store.try_acquire_lock(
+                _INBOUND_EVENT_HANDLING_LOCK_NAME, room_id
+            )
+            if lock:
+                logger.info("Handling old staged inbound events in %s", room_id)
+                self._process_incoming_pdus_in_room_inner(
+                    room_id,
+                    room_version,
+                    lock,
+                )
+
+            # We pause a bit so that we don't start handling all rooms at once.
+            await self._clock.sleep(random.uniform(0, 0.1))
+
     async def on_backfill_request(
         self, origin: str, room_id: str, versions: List[str], limit: int
     ) -> Tuple[int, Dict[str, Any]]:
@@ -166,6 +201,12 @@ class FederationServer(FederationBase):
     async def on_incoming_transaction(
         self, origin: str, transaction_data: JsonDict
     ) -> Tuple[int, Dict[str, Any]]:
+        # If we receive a transaction we should make sure that kick off handling
+        # any old events in the staging area.
+        if not self._started_handling_of_staged_events:
+            self._started_handling_of_staged_events = True
+            self._handle_old_staged_events()
+
         # keep this as early as possible to make the calculated origin ts as
         # accurate as possible.
         request_time = self._clock.time_msec()
@@ -882,25 +923,28 @@ class FederationServer(FederationBase):
         room_id: str,
         room_version: RoomVersion,
         lock: Lock,
-        latest_origin: str,
-        latest_event: EventBase,
+        latest_origin: Optional[str] = None,
+        latest_event: Optional[EventBase] = None,
     ) -> None:
         """Process events in the staging area for the given room.
 
         The latest_origin and latest_event args are the latest origin and event
-        received.
+        received (or None to simply pull the next event from the database).
         """
 
         # The common path is for the event we just received be the only event in
         # the room, so instead of pulling the event out of the DB and parsing
         # the event we just pull out the next event ID and check if that matches.
-        next_origin, next_event_id = await self.store.get_next_staged_event_id_for_room(
-            room_id
-        )
-        if next_origin == latest_origin and next_event_id == latest_event.event_id:
-            origin = latest_origin
-            event = latest_event
-        else:
+        if latest_event is not None and latest_origin is not None:
+            (
+                next_origin,
+                next_event_id,
+            ) = await self.store.get_next_staged_event_id_for_room(room_id)
+            if next_origin != latest_origin or next_event_id != latest_event.event_id:
+                latest_origin = None
+                latest_event = None
+
+        if latest_origin is None or latest_event is None:
             next = await self.store.get_next_staged_event_for_room(
                 room_id, room_version
             )
@@ -908,6 +952,9 @@ class FederationServer(FederationBase):
                 return
 
             origin, event = next
+        else:
+            origin = latest_origin
+            event = latest_event
 
         # We loop round until there are no more events in the room in the
         # staging area, or we fail to get the lock (which means another process
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 08d75b0d41..c4474df975 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1207,6 +1207,15 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
         return origin, event
 
+    async def get_all_rooms_with_staged_incoming_events(self) -> List[str]:
+        """Get the room IDs of all events currently staged."""
+        return await self.db_pool.simple_select_onecol(
+            table="federation_inbound_events_staging",
+            keyvalues={},
+            retcol="DISTINCT room_id",
+            desc="get_all_rooms_with_staged_incoming_events",
+        )
+
     @wrap_as_background_process("_get_stats_for_federation_staging")
     async def _get_stats_for_federation_staging(self):
         """Update the prometheus metrics for the inbound federation staging area."""