diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 52914febf9..7cdc9fe98f 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1474,32 +1474,38 @@ class EventsWorkerStore(SQLBaseStore):
# the batches as big as possible.
results: Set[str] = set()
- for chunk in batch_iter(event_ids, 500):
- r = await self._have_seen_events_dict(
- [(room_id, event_id) for event_id in chunk]
+ for event_ids_chunk in batch_iter(event_ids, 500):
+ events_seen_dict = await self._have_seen_events_dict(
+ room_id, event_ids_chunk
+ )
+ results.update(
+ eid for (eid, have_event) in events_seen_dict.items() if have_event
)
- results.update(eid for ((_rid, eid), have_event) in r.items() if have_event)
return results
- @cachedList(cached_method_name="have_seen_event", list_name="keys")
+ @cachedList(cached_method_name="have_seen_event", list_name="event_ids")
async def _have_seen_events_dict(
- self, keys: Collection[Tuple[str, str]]
- ) -> Dict[Tuple[str, str], bool]:
+ self,
+ room_id: str,
+ event_ids: Collection[str],
+ ) -> Dict[str, bool]:
"""Helper for have_seen_events
Returns:
- a dict {(room_id, event_id)-> bool}
+ a dict {event_id -> bool}
"""
# if the event cache contains the event, obviously we've seen it.
cache_results = {
- (rid, eid)
- for (rid, eid) in keys
- if await self._get_event_cache.contains((eid,))
+ event_id
+ for event_id in event_ids
+ if await self._get_event_cache.contains((event_id,))
}
results = dict.fromkeys(cache_results, True)
- remaining = [k for k in keys if k not in cache_results]
+ remaining = [
+ event_id for event_id in event_ids if event_id not in cache_results
+ ]
if not remaining:
return results
@@ -1511,23 +1517,21 @@ class EventsWorkerStore(SQLBaseStore):
sql = "SELECT event_id FROM events AS e WHERE "
clause, args = make_in_list_sql_clause(
- txn.database_engine, "e.event_id", [eid for (_rid, eid) in remaining]
+ txn.database_engine, "e.event_id", remaining
)
txn.execute(sql + clause, args)
found_events = {eid for eid, in txn}
# ... and then we can update the results for each key
- results.update(
- {(rid, eid): (eid in found_events) for (rid, eid) in remaining}
- )
+ results.update({eid: (eid in found_events) for eid in remaining})
await self.db_pool.runInteraction("have_seen_events", have_seen_events_txn)
return results
@cached(max_entries=100000, tree=True)
async def have_seen_event(self, room_id: str, event_id: str) -> bool:
- res = await self._have_seen_events_dict(((room_id, event_id),))
- return res[(room_id, event_id)]
+ res = await self._have_seen_events_dict(room_id, [event_id])
+ return res[event_id]
def _get_current_state_event_counts_txn(
self, txn: LoggingTransaction, room_id: str
|