summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
authorEric Eastwood <erice@element.io>2022-08-11 14:58:07 -0500
committerEric Eastwood <erice@element.io>2022-08-11 14:58:07 -0500
commit477fad64b8795683063803b5e532520744acb34d (patch)
treef9d1bf3a96aa2e5789733a3cb4a105d962c76d5a /synapse/storage/databases
parentTune buckets (diff)
downloadsynapse-477fad64b8795683063803b5e532520744acb34d.tar.xz
Refactor recursive code so we can wrap just the redaction part
See https://github.com/matrix-org/synapse/pull/13489#discussion_r943444260
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/events_worker.py21
1 files changed, 15 insertions, 6 deletions
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index e9ff6cfb34..0a19d1fbd3 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -54,6 +54,7 @@ from synapse.logging.context import (
     current_context,
     make_deferred_yieldable,
 )
+from synapse.logging.opentracing import start_active_span, tag_args, trace
 from synapse.metrics.background_process_metrics import (
     run_as_background_process,
     wrap_as_background_process,
@@ -430,6 +431,8 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {e.event_id: e for e in events}
 
+    @trace
+    @tag_args
     async def get_events_as_list(
         self,
         event_ids: Collection[str],
@@ -1090,15 +1093,11 @@ class EventsWorkerStore(SQLBaseStore):
         """
         fetched_event_ids: Set[str] = set()
         fetched_events: Dict[str, _EventRow] = {}
-        events_to_fetch = event_ids
-
-        while events_to_fetch:
-            row_map = await self._enqueue_events(events_to_fetch)
 
+        async def _recursively_fetch_redactions(row_map: Dict[str, _EventRow]) -> None:
             # we need to recursively fetch any redactions of those events
             redaction_ids: Set[str] = set()
-            for event_id in events_to_fetch:
-                row = row_map.get(event_id)
+            for event_id, row in row_map.items():
                 fetched_event_ids.add(event_id)
                 if row:
                     fetched_events[event_id] = row
@@ -1107,6 +1106,14 @@ class EventsWorkerStore(SQLBaseStore):
             events_to_fetch = redaction_ids.difference(fetched_event_ids)
             if events_to_fetch:
                 logger.debug("Also fetching redaction events %s", events_to_fetch)
+                row_map = await self._enqueue_events(events_to_fetch)
+                await _recursively_fetch_redactions(row_map)
+
+        events_to_fetch = event_ids
+        row_map = await self._enqueue_events(events_to_fetch)
+
+        with start_active_span("recursively fetching redactions"):
+            await _recursively_fetch_redactions(row_map)
 
         # build a map from event_id to EventBase
         event_map: Dict[str, EventBase] = {}
@@ -1424,6 +1431,8 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {r["event_id"] for r in rows}
 
+    @trace
+    @tag_args
     async def have_seen_events(
         self, room_id: str, event_ids: Iterable[str]
     ) -> Set[str]: