summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2021-08-02 14:37:25 +0100
committerGitHub <noreply@github.com>2021-08-02 13:37:25 +0000
commit01d45fe964d323e7f66358c2db57d00a44bf2274 (patch)
tree738e1198533a9f6e9a29df6c5aa42e7c9ee4ee49
parentAllow setting transaction limit for db connections (#10440) (diff)
downloadsynapse-01d45fe964d323e7f66358c2db57d00a44bf2274.tar.xz
Prune inbound federation queues if they get too long (#10390)
-rw-r--r--changelog.d/10390.misc1
-rw-r--r--synapse/federation/federation_server.py17
-rw-r--r--synapse/storage/databases/main/event_federation.py104
-rw-r--r--tests/storage/test_event_federation.py57
4 files changed, 177 insertions, 2 deletions
diff --git a/changelog.d/10390.misc b/changelog.d/10390.misc
new file mode 100644
index 0000000000..911a5733ee
--- /dev/null
+++ b/changelog.d/10390.misc
@@ -0,0 +1 @@
+Prune inbound federation inbound queues for a room if they get too large.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 2892a11d7d..145b9161d9 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -1024,6 +1024,23 @@ class FederationServer(FederationBase):
 
             origin, event = next
 
+            # Prune the event queue if it's getting large.
+            #
+            # We do this *after* handling the first event as the common case is
+            # that the queue is empty (/has the single event in), and so there's
+            # no need to do this check.
+            pruned = await self.store.prune_staged_events_in_room(room_id, room_version)
+            if pruned:
+                # If we have pruned the queue check we need to refetch the next
+                # event to handle.
+                next = await self.store.get_next_staged_event_for_room(
+                    room_id, room_version
+                )
+                if not next:
+                    break
+
+                origin, event = next
+
             lock = await self.store.try_acquire_lock(
                 _INBOUND_EVENT_HANDLING_LOCK_NAME, room_id
             )
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 547e43ab98..44018c1c31 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -16,11 +16,11 @@ import logging
 from queue import Empty, PriorityQueue
 from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple
 
-from prometheus_client import Gauge
+from prometheus_client import Counter, Gauge
 
 from synapse.api.constants import MAX_DEPTH
 from synapse.api.errors import StoreError
-from synapse.api.room_versions import RoomVersion
+from synapse.api.room_versions import EventFormatVersions, RoomVersion
 from synapse.events import EventBase, make_event_from_dict
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
@@ -44,6 +44,12 @@ number_pdus_in_federation_queue = Gauge(
     "The total number of events in the inbound federation staging",
 )
 
+pdus_pruned_from_federation_queue = Counter(
+    "synapse_federation_server_number_inbound_pdu_pruned",
+    "The number of events in the inbound federation staging that have been "
+    "pruned due to the queue getting too long",
+)
+
 logger = logging.getLogger(__name__)
 
 
@@ -1277,6 +1283,100 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
         return origin, event
 
+    async def prune_staged_events_in_room(
+        self,
+        room_id: str,
+        room_version: RoomVersion,
+    ) -> bool:
+        """Checks if there are lots of staged events for the room, and if so
+        prune them down.
+
+        Returns:
+            Whether any events were pruned
+        """
+
+        # First check the size of the queue.
+        count = await self.db_pool.simple_select_one_onecol(
+            table="federation_inbound_events_staging",
+            keyvalues={"room_id": room_id},
+            retcol="COALESCE(COUNT(*), 0)",
+            desc="prune_staged_events_in_room_count",
+        )
+
+        if count < 100:
+            return False
+
+        # If the queue is too large, then we want clear the entire queue,
+        # keeping only the forward extremities (i.e. the events not referenced
+        # by other events in the queue). We do this so that we can always
+        # backpaginate in all the events we have dropped.
+        rows = await self.db_pool.simple_select_list(
+            table="federation_inbound_events_staging",
+            keyvalues={"room_id": room_id},
+            retcols=("event_id", "event_json"),
+            desc="prune_staged_events_in_room_fetch",
+        )
+
+        # Find the set of events referenced by those in the queue, as well as
+        # collecting all the event IDs in the queue.
+        referenced_events: Set[str] = set()
+        seen_events: Set[str] = set()
+        for row in rows:
+            event_id = row["event_id"]
+            seen_events.add(event_id)
+            event_d = db_to_json(row["event_json"])
+
+            # We don't bother parsing the dicts into full blown event objects,
+            # as that is needlessly expensive.
+
+            # We haven't checked that the `prev_events` have the right format
+            # yet, so we check as we go.
+            prev_events = event_d.get("prev_events", [])
+            if not isinstance(prev_events, list):
+                logger.info("Invalid prev_events for %s", event_id)
+                continue
+
+            if room_version.event_format == EventFormatVersions.V1:
+                for prev_event_tuple in prev_events:
+                    if not isinstance(prev_event_tuple, list) or len(prev_events) != 2:
+                        logger.info("Invalid prev_events for %s", event_id)
+                        break
+
+                    prev_event_id = prev_event_tuple[0]
+                    if not isinstance(prev_event_id, str):
+                        logger.info("Invalid prev_events for %s", event_id)
+                        break
+
+                    referenced_events.add(prev_event_id)
+            else:
+                for prev_event_id in prev_events:
+                    if not isinstance(prev_event_id, str):
+                        logger.info("Invalid prev_events for %s", event_id)
+                        break
+
+                    referenced_events.add(prev_event_id)
+
+        to_delete = referenced_events & seen_events
+        if not to_delete:
+            return False
+
+        pdus_pruned_from_federation_queue.inc(len(to_delete))
+        logger.info(
+            "Pruning %d events in room %s from federation queue",
+            len(to_delete),
+            room_id,
+        )
+
+        await self.db_pool.simple_delete_many(
+            table="federation_inbound_events_staging",
+            keyvalues={"room_id": room_id},
+            iterable=to_delete,
+            column="event_id",
+            desc="prune_staged_events_in_room_delete",
+        )
+
+        return True
+
     async def get_all_rooms_with_staged_incoming_events(self) -> List[str]:
         """Get the room IDs of all events currently staged."""
         return await self.db_pool.simple_select_onecol(
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index a0e2259478..c3fcf7e7b4 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -15,7 +15,9 @@
 import attr
 from parameterized import parameterized
 
+from synapse.api.room_versions import RoomVersions
 from synapse.events import _EventInternalMetadata
+from synapse.util import json_encoder
 
 import tests.unittest
 import tests.utils
@@ -504,6 +506,61 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         )
         self.assertSetEqual(difference, set())
 
+    def test_prune_inbound_federation_queue(self):
+        "Test that pruning of inbound federation queues work"
+
+        room_id = "some_room_id"
+
+        # Insert a bunch of events that all reference the previous one.
+        self.get_success(
+            self.store.db_pool.simple_insert_many(
+                table="federation_inbound_events_staging",
+                values=[
+                    {
+                        "origin": "some_origin",
+                        "room_id": room_id,
+                        "received_ts": 0,
+                        "event_id": f"$fake_event_id_{i + 1}",
+                        "event_json": json_encoder.encode(
+                            {"prev_events": [f"$fake_event_id_{i}"]}
+                        ),
+                        "internal_metadata": "{}",
+                    }
+                    for i in range(500)
+                ],
+                desc="test_prune_inbound_federation_queue",
+            )
+        )
+
+        # Calling prune once should return True, i.e. a prune happen. The second
+        # time it shouldn't.
+        pruned = self.get_success(
+            self.store.prune_staged_events_in_room(room_id, RoomVersions.V6)
+        )
+        self.assertTrue(pruned)
+
+        pruned = self.get_success(
+            self.store.prune_staged_events_in_room(room_id, RoomVersions.V6)
+        )
+        self.assertFalse(pruned)
+
+        # Assert that we only have a single event left in the queue, and that it
+        # is the last one.
+        count = self.get_success(
+            self.store.db_pool.simple_select_one_onecol(
+                table="federation_inbound_events_staging",
+                keyvalues={"room_id": room_id},
+                retcol="COALESCE(COUNT(*), 0)",
+                desc="test_prune_inbound_federation_queue",
+            )
+        )
+        self.assertEqual(count, 1)
+
+        _, event_id = self.get_success(
+            self.store.get_next_staged_event_id_for_room(room_id)
+        )
+        self.assertEqual(event_id, "$fake_event_id_500")
+
 
 @attr.s
 class FakeEvent: