summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/events_worker.py83
1 files changed, 49 insertions, 34 deletions
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 6d6e146ff1..c31fc00eaa 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -75,7 +75,7 @@ from synapse.storage.util.id_generators import (
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import JsonDict, get_domain_from_id
 from synapse.util import unwrapFirstError
-from synapse.util.async_helpers import ObservableDeferred
+from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.iterutils import batch_iter
@@ -640,42 +640,57 @@ class EventsWorkerStore(SQLBaseStore):
         missing_events_ids.difference_update(already_fetching_ids)
 
         if missing_events_ids:
-            log_ctx = current_context()
-            log_ctx.record_event_fetch(len(missing_events_ids))
-
-            # Add entries to `self._current_event_fetches` for each event we're
-            # going to pull from the DB. We use a single deferred that resolves
-            # to all the events we pulled from the DB (this will result in this
-            # function returning more events than requested, but that can happen
-            # already due to `_get_events_from_db`).
-            fetching_deferred: ObservableDeferred[
-                Dict[str, EventCacheEntry]
-            ] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
-            for event_id in missing_events_ids:
-                self._current_event_fetches[event_id] = fetching_deferred
-
-            # Note that _get_events_from_db is also responsible for turning db rows
-            # into FrozenEvents (via _get_event_from_row), which involves seeing if
-            # the events have been redacted, and if so pulling the redaction event out
-            # of the database to check it.
-            #
-            try:
-                missing_events = await self._get_events_from_db(
-                    missing_events_ids,
-                )
 
-                event_entry_map.update(missing_events)
-            except Exception as e:
-                with PreserveLoggingContext():
-                    fetching_deferred.errback(e)
-                raise e
-            finally:
-                # Ensure that we mark these events as no longer being fetched.
+            async def get_missing_events_from_db() -> Dict[str, EventCacheEntry]:
+                """Fetches the events in `missing_event_ids` from the database.
+
+                Also creates entries in `self._current_event_fetches` to allow
+                concurrent `_get_events_from_cache_or_db` calls to reuse the same fetch.
+                """
+                log_ctx = current_context()
+                log_ctx.record_event_fetch(len(missing_events_ids))
+
+                # Add entries to `self._current_event_fetches` for each event we're
+                # going to pull from the DB. We use a single deferred that resolves
+                # to all the events we pulled from the DB (this will result in this
+                # function returning more events than requested, but that can happen
+                # already due to `_get_events_from_db`).
+                fetching_deferred: ObservableDeferred[
+                    Dict[str, EventCacheEntry]
+                ] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
                 for event_id in missing_events_ids:
-                    self._current_event_fetches.pop(event_id, None)
+                    self._current_event_fetches[event_id] = fetching_deferred
 
-            with PreserveLoggingContext():
-                fetching_deferred.callback(missing_events)
+                # Note that _get_events_from_db is also responsible for turning db rows
+                # into FrozenEvents (via _get_event_from_row), which involves seeing if
+                # the events have been redacted, and if so pulling the redaction event
+                # out of the database to check it.
+                #
+                try:
+                    missing_events = await self._get_events_from_db(
+                        missing_events_ids,
+                    )
+                except Exception as e:
+                    with PreserveLoggingContext():
+                        fetching_deferred.errback(e)
+                    raise e
+                finally:
+                    # Ensure that we mark these events as no longer being fetched.
+                    for event_id in missing_events_ids:
+                        self._current_event_fetches.pop(event_id, None)
+
+                with PreserveLoggingContext():
+                    fetching_deferred.callback(missing_events)
+
+                return missing_events
+
+            # We must allow the database fetch to complete in the presence of
+            # cancellations, since multiple `_get_events_from_cache_or_db` calls can
+            # reuse the same fetch.
+            missing_events: Dict[str, EventCacheEntry] = await delay_cancellation(
+                get_missing_events_from_db()
+            )
+            event_entry_map.update(missing_events)
 
         if already_fetching_deferreds:
             # Wait for the other event requests to finish and add their results