diff options
-rw-r--r-- | synapse/storage/databases/main/events_worker.py | 21 |
1 files changed, 15 insertions, 6 deletions
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index e9ff6cfb34..0a19d1fbd3 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -54,6 +54,7 @@ from synapse.logging.context import ( current_context, make_deferred_yieldable, ) +from synapse.logging.opentracing import start_active_span, tag_args, trace from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, @@ -430,6 +431,8 @@ class EventsWorkerStore(SQLBaseStore): return {e.event_id: e for e in events} + @trace + @tag_args async def get_events_as_list( self, event_ids: Collection[str], @@ -1090,15 +1093,11 @@ class EventsWorkerStore(SQLBaseStore): """ fetched_event_ids: Set[str] = set() fetched_events: Dict[str, _EventRow] = {} - events_to_fetch = event_ids - - while events_to_fetch: - row_map = await self._enqueue_events(events_to_fetch) + async def _recursively_fetch_redactions(row_map: Dict[str, _EventRow]) -> None: # we need to recursively fetch any redactions of those events redaction_ids: Set[str] = set() - for event_id in events_to_fetch: - row = row_map.get(event_id) + for event_id, row in row_map.items(): fetched_event_ids.add(event_id) if row: fetched_events[event_id] = row @@ -1107,6 +1106,14 @@ class EventsWorkerStore(SQLBaseStore): events_to_fetch = redaction_ids.difference(fetched_event_ids) if events_to_fetch: logger.debug("Also fetching redaction events %s", events_to_fetch) + row_map = await self._enqueue_events(events_to_fetch) + await _recursively_fetch_redactions(row_map) + + events_to_fetch = event_ids + row_map = await self._enqueue_events(events_to_fetch) + + with start_active_span("recursively fetching redactions"): + await _recursively_fetch_redactions(row_map) # build a map from event_id to EventBase event_map: Dict[str, EventBase] = {} @@ -1424,6 +1431,8 @@ class EventsWorkerStore(SQLBaseStore): return {r["event_id"] for r in rows} + @trace + @tag_args async def have_seen_events( self, room_id: str, event_ids: Iterable[str] ) -> Set[str]: |