diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 6d6e146ff1..c31fc00eaa 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -75,7 +75,7 @@ from synapse.storage.util.id_generators import (
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
from synapse.util import unwrapFirstError
-from synapse.util.async_helpers import ObservableDeferred
+from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
@@ -640,42 +640,57 @@ class EventsWorkerStore(SQLBaseStore):
missing_events_ids.difference_update(already_fetching_ids)
if missing_events_ids:
- log_ctx = current_context()
- log_ctx.record_event_fetch(len(missing_events_ids))
-
- # Add entries to `self._current_event_fetches` for each event we're
- # going to pull from the DB. We use a single deferred that resolves
- # to all the events we pulled from the DB (this will result in this
- # function returning more events than requested, but that can happen
- # already due to `_get_events_from_db`).
- fetching_deferred: ObservableDeferred[
- Dict[str, EventCacheEntry]
- ] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
- for event_id in missing_events_ids:
- self._current_event_fetches[event_id] = fetching_deferred
-
- # Note that _get_events_from_db is also responsible for turning db rows
- # into FrozenEvents (via _get_event_from_row), which involves seeing if
- # the events have been redacted, and if so pulling the redaction event out
- # of the database to check it.
- #
- try:
- missing_events = await self._get_events_from_db(
- missing_events_ids,
- )
- event_entry_map.update(missing_events)
- except Exception as e:
- with PreserveLoggingContext():
- fetching_deferred.errback(e)
- raise e
- finally:
- # Ensure that we mark these events as no longer being fetched.
+ async def get_missing_events_from_db() -> Dict[str, EventCacheEntry]:
+ """Fetches the events in `missing_event_ids` from the database.
+
+ Also creates entries in `self._current_event_fetches` to allow
+ concurrent `_get_events_from_cache_or_db` calls to reuse the same fetch.
+ """
+ log_ctx = current_context()
+ log_ctx.record_event_fetch(len(missing_events_ids))
+
+ # Add entries to `self._current_event_fetches` for each event we're
+ # going to pull from the DB. We use a single deferred that resolves
+ # to all the events we pulled from the DB (this will result in this
+ # function returning more events than requested, but that can happen
+ # already due to `_get_events_from_db`).
+ fetching_deferred: ObservableDeferred[
+ Dict[str, EventCacheEntry]
+ ] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
for event_id in missing_events_ids:
- self._current_event_fetches.pop(event_id, None)
+ self._current_event_fetches[event_id] = fetching_deferred
- with PreserveLoggingContext():
- fetching_deferred.callback(missing_events)
+ # Note that _get_events_from_db is also responsible for turning db rows
+ # into FrozenEvents (via _get_event_from_row), which involves seeing if
+ # the events have been redacted, and if so pulling the redaction event
+ # out of the database to check it.
+ #
+ try:
+ missing_events = await self._get_events_from_db(
+ missing_events_ids,
+ )
+ except Exception as e:
+ with PreserveLoggingContext():
+ fetching_deferred.errback(e)
+ raise e
+ finally:
+ # Ensure that we mark these events as no longer being fetched.
+ for event_id in missing_events_ids:
+ self._current_event_fetches.pop(event_id, None)
+
+ with PreserveLoggingContext():
+ fetching_deferred.callback(missing_events)
+
+ return missing_events
+
+ # We must allow the database fetch to complete in the presence of
+ # cancellations, since multiple `_get_events_from_cache_or_db` calls can
+ # reuse the same fetch.
+ missing_events: Dict[str, EventCacheEntry] = await delay_cancellation(
+ get_missing_events_from_db()
+ )
+ event_entry_map.update(missing_events)
if already_fetching_deferreds:
# Wait for the other event requests to finish and add their results
|