diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index 874d0a56bc..6d6bb1d5d3 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -218,37 +218,23 @@ class EventsWorkerStore(SQLBaseStore):
if not event_ids:
defer.returnValue([])
- event_id_list = event_ids
- event_ids = set(event_ids)
-
- event_entry_map = self._get_events_from_cache(
- event_ids, allow_rejected=allow_rejected
+ # there may be duplicates in event_ids, so we deduplicate by passing a set
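+ # (we still iterate over the original list below, so duplicates in the
+ # input are preserved and events are returned in the requested order)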
+ event_entry_map = yield self._get_events_from_cache_or_db(
+ set(event_ids), allow_rejected=allow_rejected
)
- missing_events_ids = [e for e in event_ids if e not in event_entry_map]
-
- if missing_events_ids:
- log_ctx = LoggingContext.current_context()
- log_ctx.record_event_fetch(len(missing_events_ids))
-
- # Note that _enqueue_events is also responsible for turning db rows
- # into FrozenEvents (via _get_event_from_row), which involves seeing if
- # the events have been redacted, and if so pulling the redaction event out
- # of the database to check it.
- #
- # _enqueue_events is a bit of a rubbish name but naming is hard.
- missing_events = yield self._enqueue_events(
- missing_events_ids, allow_rejected=allow_rejected
- )
-
- event_entry_map.update(missing_events)
-
events = []
- for event_id in event_id_list:
+ for event_id in event_ids:
entry = event_entry_map.get(event_id, None)
if not entry:
continue
+ if not allow_rejected:
+ assert not entry.event.rejected_reason, (
+ "rejected event returned from _get_events_from_cache_or_db despite "
+ "allow_rejected=False"
+ )
+
# Starting in room version v3, some redactions need to be rechecked if we
# didn't have the redacted event at the time, so we recheck on read
# instead.
@@ -291,34 +277,74 @@ class EventsWorkerStore(SQLBaseStore):
# recheck.
entry.event.internal_metadata.recheck_redaction = False
else:
- # We don't have the event that is being redacted, so we
- # assume that the event isn't authorized for now. (If we
- # later receive the event, then we will always redact
- # it anyway, since we have this redaction)
+ # We either don't have the event that is being redacted (in which
+ # case we assume the redaction isn't authorised for now), or the
+ # senders don't match (in which case the redaction will never be
+ # authorised). Either way, we shouldn't return the redaction.
+ #
+ # (If we later receive the event, then we will redact it anyway,
+ # since we have this redaction)
continue
- if allow_rejected or not entry.event.rejected_reason:
- if check_redacted and entry.redacted_event:
- event = entry.redacted_event
- else:
- event = entry.event
-
- events.append(event)
-
- if get_prev_content:
- if "replaces_state" in event.unsigned:
- prev = yield self.get_event(
- event.unsigned["replaces_state"],
- get_prev_content=False,
- allow_none=True,
- )
- if prev:
- event.unsigned = dict(event.unsigned)
- event.unsigned["prev_content"] = prev.content
- event.unsigned["prev_sender"] = prev.sender
+ if check_redacted and entry.redacted_event:
+ event = entry.redacted_event
+ else:
+ event = entry.event
+
+ events.append(event)
+
+ if get_prev_content:
+ if "replaces_state" in event.unsigned:
+ prev = yield self.get_event(
+ event.unsigned["replaces_state"],
+ get_prev_content=False,
+ allow_none=True,
+ )
+ if prev:
+ event.unsigned = dict(event.unsigned)
+ event.unsigned["prev_content"] = prev.content
+ event.unsigned["prev_sender"] = prev.sender
defer.returnValue(events)
+
+ @defer.inlineCallbacks
+ def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
+ """Fetch a bunch of events from the cache or the database.
+
+ If events are pulled from the database, they will be cached for future lookups.
+
+ Args:
+ event_ids (Iterable[str]): The event_ids of the events to fetch
+ allow_rejected (bool): Whether to include rejected events
+
+ Returns:
+ Deferred[Dict[str, _EventCacheEntry]]:
+ map from event id to result
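+ (events that could not be found are omitted from the map; rejected
+ events are also omitted unless allow_rejected is True)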
+ """
+ event_entry_map = self._get_events_from_cache(
+ event_ids, allow_rejected=allow_rejected
+ )
+
+ missing_events_ids = [e for e in event_ids if e not in event_entry_map]
+
+ if missing_events_ids:
+ log_ctx = LoggingContext.current_context()
+ log_ctx.record_event_fetch(len(missing_events_ids))
+
+ # Note that _enqueue_events is also responsible for turning db rows
+ # into FrozenEvents (via _get_event_from_row), which involves seeing if
+ # the events have been redacted, and if so pulling the redaction event out
+ # of the database to check it.
+ #
+ # _enqueue_events is a bit of a rubbish name but naming is hard.
+ missing_events = yield self._enqueue_events(
+ missing_events_ids, allow_rejected=allow_rejected
+ )
+
+ event_entry_map.update(missing_events)
+
+ defer.returnValue(event_entry_map)
+
def _invalidate_get_event_cache(self, event_id):
self._get_event_cache.invalidate((event_id,))
@@ -326,7 +352,7 @@ class EventsWorkerStore(SQLBaseStore):
"""Fetch events from the caches
Args:
- events (list(str)): list of event_ids to fetch
+ events (Iterable[str]): list of event_ids to fetch
allow_rejected (bool): Whether to return events that were rejected
update_metrics (bool): Whether to update the cache hit ratio metrics