diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index 6e5f1cf6ee..e15e7d86fe 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -17,7 +17,6 @@ from __future__ import division
 
 import itertools
 import logging
-import operator
 from collections import namedtuple
 
 from canonicaljson import json
@@ -30,12 +29,7 @@ from synapse.api.room_versions import EventFormatVersions
 from synapse.events import FrozenEvent, event_type_from_format_version  # noqa: F401
 from synapse.events.snapshot import EventContext  # noqa: F401
 from synapse.events.utils import prune_event
-from synapse.logging.context import (
-    LoggingContext,
-    PreserveLoggingContext,
-    make_deferred_yieldable,
-    run_in_background,
-)
+from synapse.logging.context import LoggingContext, PreserveLoggingContext
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import get_domain_from_id
 from synapse.util import batch_iter
@@ -468,39 +462,49 @@ class EventsWorkerStore(SQLBaseStore):
         Returns:
             Deferred[Dict[str, _EventCacheEntry]]:
-                map from event id to result.
+                map from event id to result. May return extra events which
+                weren't asked for.
         """
-        if not event_ids:
-            return {}
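+        # map from event id to database row, for everything we have fetched so
+        # far (the value is None where the event was not found)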
+        fetched_events = {}
+        events_to_fetch = event_ids
 
-        row_map = yield self._enqueue_events(event_ids)
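+        # fetch the requested events, then loop round to fetch any redaction
+        # events which they refer to and which we haven't already fetched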
+        while events_to_fetch:
+            row_map = yield self._enqueue_events(events_to_fetch)
 
-        rows = (row_map.get(event_id) for event_id in event_ids)
+            # we need to recursively fetch any redactions of those events
+            redaction_ids = set()
+            for event_id in events_to_fetch:
+                row = row_map.get(event_id)
+                fetched_events[event_id] = row
+                if row:
+                    redaction_ids.update(row["redactions"])
 
-        # filter out absent rows
-        rows = filter(operator.truth, rows)
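+            # only go round the loop again for redactions we haven't yet fetched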
+            events_to_fetch = redaction_ids.difference(fetched_events.keys())
+            if events_to_fetch:
+                logger.debug("Also fetching redaction events %s", events_to_fetch)
 
-        if not allow_rejected:
-            rows = (r for r in rows if r["rejected_reason"] is None)
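+        # build the result, skipping events which weren't found and, unless
+        # allow_rejected is set, events which were rejected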
+        result_map = {}
+        for event_id, row in fetched_events.items():
+            if not row:
+                continue
+            assert row["event_id"] == event_id
 
-        res = yield make_deferred_yieldable(
-            defer.gatherResults(
-                [
-                    run_in_background(
-                        self._get_event_from_row,
-                        row["internal_metadata"],
-                        row["json"],
-                        row["redactions"],
-                        rejected_reason=row["rejected_reason"],
-                        format_version=row["format_version"],
-                    )
-                    for row in rows
-                ],
-                consumeErrors=True,
+            rejected_reason = row["rejected_reason"]
+
+            if not allow_rejected and rejected_reason:
+                continue
+
+            cache_entry = yield self._get_event_from_row(
+                row["internal_metadata"],
+                row["json"],
+                row["redactions"],
+                rejected_reason=row["rejected_reason"],
+                format_version=row["format_version"],
             )
-        )
-        return {e.event.event_id: e for e in res if e}
+            result_map[event_id] = cache_entry
+
+        return result_map
 
     @defer.inlineCallbacks
     def _enqueue_events(self, events):