summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorEric Eastwood <erice@element.io>2022-09-22 17:36:54 -0500
committerEric Eastwood <erice@element.io>2022-09-22 17:36:54 -0500
commit0cdc7bf43253a06ae68b131e89f71f7f1578a1ec (patch)
tree2f36e595a7fa68edfacac60aabc5855ea402abff /synapse
parentInvalidate cache like #13796 (diff)
downloadsynapse-0cdc7bf43253a06ae68b131e89f71f7f1578a1ec.tar.xz
Fix `@cachedList` on `_have_seen_events_dict`
As mentioned by @erikjohnston,
https://github.com/matrix-org/synapse/issues/13865#issuecomment-1254751569
Diffstat (limited to 'synapse')
-rw-r--r--synapse/storage/databases/main/cache.py2
-rw-r--r--synapse/storage/databases/main/events_worker.py40
2 files changed, 23 insertions, 19 deletions
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 2dbf93a4e5..2c421151c1 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -223,7 +223,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         # process triggering the invalidation is responsible for clearing any external
         # cached objects.
         self._invalidate_local_get_event_cache(event_id)
-        self.have_seen_event.invalidate(((room_id, event_id),))
+        self.have_seen_event.invalidate((room_id, event_id))
 
         self.get_latest_event_ids_in_room.invalidate((room_id,))
 
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 52914febf9..7cdc9fe98f 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1474,32 +1474,38 @@ class EventsWorkerStore(SQLBaseStore):
         # the batches as big as possible.
 
         results: Set[str] = set()
-        for chunk in batch_iter(event_ids, 500):
-            r = await self._have_seen_events_dict(
-                [(room_id, event_id) for event_id in chunk]
+        for event_ids_chunk in batch_iter(event_ids, 500):
+            events_seen_dict = await self._have_seen_events_dict(
+                room_id, event_ids_chunk
+            )
+            results.update(
+                eid for (eid, have_event) in events_seen_dict.items() if have_event
             )
-            results.update(eid for ((_rid, eid), have_event) in r.items() if have_event)
 
         return results
 
-    @cachedList(cached_method_name="have_seen_event", list_name="keys")
+    @cachedList(cached_method_name="have_seen_event", list_name="event_ids")
     async def _have_seen_events_dict(
-        self, keys: Collection[Tuple[str, str]]
-    ) -> Dict[Tuple[str, str], bool]:
+        self,
+        room_id: str,
+        event_ids: Collection[str],
+    ) -> Dict[str, bool]:
         """Helper for have_seen_events
 
         Returns:
-             a dict {(room_id, event_id)-> bool}
+             a dict {event_id -> bool}
         """
         # if the event cache contains the event, obviously we've seen it.
 
         cache_results = {
-            (rid, eid)
-            for (rid, eid) in keys
-            if await self._get_event_cache.contains((eid,))
+            event_id
+            for event_id in event_ids
+            if await self._get_event_cache.contains((event_id,))
         }
         results = dict.fromkeys(cache_results, True)
-        remaining = [k for k in keys if k not in cache_results]
+        remaining = [
+            event_id for event_id in event_ids if event_id not in cache_results
+        ]
         if not remaining:
             return results
 
@@ -1511,23 +1517,21 @@ class EventsWorkerStore(SQLBaseStore):
 
             sql = "SELECT event_id FROM events AS e WHERE "
             clause, args = make_in_list_sql_clause(
-                txn.database_engine, "e.event_id", [eid for (_rid, eid) in remaining]
+                txn.database_engine, "e.event_id", remaining
             )
             txn.execute(sql + clause, args)
             found_events = {eid for eid, in txn}
 
             # ... and then we can update the results for each key
-            results.update(
-                {(rid, eid): (eid in found_events) for (rid, eid) in remaining}
-            )
+            results.update({eid: (eid in found_events) for eid in remaining})
 
         await self.db_pool.runInteraction("have_seen_events", have_seen_events_txn)
         return results
 
     @cached(max_entries=100000, tree=True)
     async def have_seen_event(self, room_id: str, event_id: str) -> bool:
-        res = await self._have_seen_events_dict(((room_id, event_id),))
-        return res[(room_id, event_id)]
+        res = await self._have_seen_events_dict(room_id, [event_id])
+        return res[event_id]
 
     def _get_current_state_event_counts_txn(
         self, txn: LoggingTransaction, room_id: str