summary refs log tree commit diff
path: root/synapse/storage/events_worker.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/events_worker.py')
-rw-r--r--synapse/storage/events_worker.py68
1 files changed, 36 insertions, 32 deletions
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index 6e5f1cf6ee..e15e7d86fe 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -17,7 +17,6 @@ from __future__ import division
 
 import itertools
 import logging
-import operator
 from collections import namedtuple
 
 from canonicaljson import json
@@ -30,12 +29,7 @@ from synapse.api.room_versions import EventFormatVersions
 from synapse.events import FrozenEvent, event_type_from_format_version  # noqa: F401
 from synapse.events.snapshot import EventContext  # noqa: F401
 from synapse.events.utils import prune_event
-from synapse.logging.context import (
-    LoggingContext,
-    PreserveLoggingContext,
-    make_deferred_yieldable,
-    run_in_background,
-)
+from synapse.logging.context import LoggingContext, PreserveLoggingContext
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import get_domain_from_id
 from synapse.util import batch_iter
@@ -468,39 +462,49 @@ class EventsWorkerStore(SQLBaseStore):
 
         Returns:
             Deferred[Dict[str, _EventCacheEntry]]:
-                map from event id to result.
+                map from event id to result. May return extra events which
+                weren't asked for.
         """
-        if not event_ids:
-            return {}
+        fetched_events = {}
+        events_to_fetch = event_ids
 
-        row_map = yield self._enqueue_events(event_ids)
+        while events_to_fetch:
+            row_map = yield self._enqueue_events(events_to_fetch)
 
-        rows = (row_map.get(event_id) for event_id in event_ids)
+            # we need to recursively fetch any redactions of those events
+            redaction_ids = set()
+            for event_id in events_to_fetch:
+                row = row_map.get(event_id)
+                fetched_events[event_id] = row
+                if row:
+                    redaction_ids.update(row["redactions"])
 
-        # filter out absent rows
-        rows = filter(operator.truth, rows)
+            events_to_fetch = redaction_ids.difference(fetched_events.keys())
+            if events_to_fetch:
+                logger.debug("Also fetching redaction events %s", events_to_fetch)
 
-        if not allow_rejected:
-            rows = (r for r in rows if r["rejected_reason"] is None)
+        result_map = {}
+        for event_id, row in fetched_events.items():
+            if not row:
+                continue
+            assert row["event_id"] == event_id
 
-        res = yield make_deferred_yieldable(
-            defer.gatherResults(
-                [
-                    run_in_background(
-                        self._get_event_from_row,
-                        row["internal_metadata"],
-                        row["json"],
-                        row["redactions"],
-                        rejected_reason=row["rejected_reason"],
-                        format_version=row["format_version"],
-                    )
-                    for row in rows
-                ],
-                consumeErrors=True,
+            rejected_reason = row["rejected_reason"]
+
+            if not allow_rejected and rejected_reason:
+                continue
+
+            cache_entry = yield self._get_event_from_row(
+                row["internal_metadata"],
+                row["json"],
+                row["redactions"],
+                rejected_reason=row["rejected_reason"],
+                format_version=row["format_version"],
             )
-        )
 
-        return {e.event.event_id: e for e in res if e}
+            result_map[event_id] = cache_entry
+
+        return result_map
 
     @defer.inlineCallbacks
     def _enqueue_events(self, events):