summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12886.misc1
-rw-r--r--synapse/storage/databases/main/events_worker.py42
2 files changed, 25 insertions, 18 deletions
diff --git a/changelog.d/12886.misc b/changelog.d/12886.misc
new file mode 100644
index 0000000000..3dd08f74ba
--- /dev/null
+++ b/changelog.d/12886.misc
@@ -0,0 +1 @@
+Refactor `have_seen_events` to reduce memory consumed when processing federation traffic.
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 5b22d6b452..a97d7e1664 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1356,14 +1356,23 @@ class EventsWorkerStore(SQLBaseStore):
         Returns:
             The set of events we have already seen.
         """
-        res = await self._have_seen_events_dict(
-            (room_id, event_id) for event_id in event_ids
-        )
-        return {eid for ((_rid, eid), have_event) in res.items() if have_event}
+
+        # @cachedList chomps lots of memory if you call it with a big list, so
+        # we break it down. However, each batch requires its own index scan, so we make
+        # the batches as big as possible.
+
+        results: Set[str] = set()
+        for chunk in batch_iter(event_ids, 500):
+            r = await self._have_seen_events_dict(
+                [(room_id, event_id) for event_id in chunk]
+            )
+            results.update(eid for ((_rid, eid), have_event) in r.items() if have_event)
+
+        return results
 
     @cachedList(cached_method_name="have_seen_event", list_name="keys")
     async def _have_seen_events_dict(
-        self, keys: Iterable[Tuple[str, str]]
+        self, keys: Collection[Tuple[str, str]]
     ) -> Dict[Tuple[str, str], bool]:
         """Helper for have_seen_events
 
@@ -1375,11 +1384,12 @@ class EventsWorkerStore(SQLBaseStore):
         cache_results = {
             (rid, eid) for (rid, eid) in keys if self._get_event_cache.contains((eid,))
         }
-        results = {x: True for x in cache_results}
+        results = dict.fromkeys(cache_results, True)
+        remaining = [k for k in keys if k not in cache_results]
+        if not remaining:
+            return results
 
-        def have_seen_events_txn(
-            txn: LoggingTransaction, chunk: Tuple[Tuple[str, str], ...]
-        ) -> None:
+        def have_seen_events_txn(txn: LoggingTransaction) -> None:
             # we deliberately do *not* query the database for room_id, to make the
             # query an index-only lookup on `events_event_id_key`.
             #
@@ -1387,21 +1397,17 @@ class EventsWorkerStore(SQLBaseStore):
 
             sql = "SELECT event_id FROM events AS e WHERE "
             clause, args = make_in_list_sql_clause(
-                txn.database_engine, "e.event_id", [eid for (_rid, eid) in chunk]
+                txn.database_engine, "e.event_id", [eid for (_rid, eid) in remaining]
             )
             txn.execute(sql + clause, args)
             found_events = {eid for eid, in txn}
 
-            # ... and then we can update the results for each row in the batch
-            results.update({(rid, eid): (eid in found_events) for (rid, eid) in chunk})
-
-        # each batch requires its own index scan, so we make the batches as big as
-        # possible.
-        for chunk in batch_iter((k for k in keys if k not in cache_results), 500):
-            await self.db_pool.runInteraction(
-                "have_seen_events", have_seen_events_txn, chunk
+            # ... and then we can update the results for each key
+            results.update(
+                {(rid, eid): (eid in found_events) for (rid, eid) in remaining}
             )
 
+        await self.db_pool.runInteraction("have_seen_events", have_seen_events_txn)
         return results
 
     @cached(max_entries=100000, tree=True)