diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/storage/databases/main/events_worker.py | 83 |
1 files changed, 49 insertions, 34 deletions
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 6d6e146ff1..c31fc00eaa 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -75,7 +75,7 @@ from synapse.storage.util.id_generators import ( from synapse.storage.util.sequence import build_sequence_generator from synapse.types import JsonDict, get_domain_from_id from synapse.util import unwrapFirstError -from synapse.util.async_helpers import ObservableDeferred +from synapse.util.async_helpers import ObservableDeferred, delay_cancellation from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter @@ -640,42 +640,57 @@ class EventsWorkerStore(SQLBaseStore): missing_events_ids.difference_update(already_fetching_ids) if missing_events_ids: - log_ctx = current_context() - log_ctx.record_event_fetch(len(missing_events_ids)) - - # Add entries to `self._current_event_fetches` for each event we're - # going to pull from the DB. We use a single deferred that resolves - # to all the events we pulled from the DB (this will result in this - # function returning more events than requested, but that can happen - # already due to `_get_events_from_db`). - fetching_deferred: ObservableDeferred[ - Dict[str, EventCacheEntry] - ] = ObservableDeferred(defer.Deferred(), consumeErrors=True) - for event_id in missing_events_ids: - self._current_event_fetches[event_id] = fetching_deferred - - # Note that _get_events_from_db is also responsible for turning db rows - # into FrozenEvents (via _get_event_from_row), which involves seeing if - # the events have been redacted, and if so pulling the redaction event out - # of the database to check it. - # - try: - missing_events = await self._get_events_from_db( - missing_events_ids, - ) - event_entry_map.update(missing_events) - except Exception as e: - with PreserveLoggingContext(): - fetching_deferred.errback(e) - raise e - finally: - # Ensure that we mark these events as no longer being fetched. + async def get_missing_events_from_db() -> Dict[str, EventCacheEntry]: + """Fetches the events in `missing_event_ids` from the database. + + Also creates entries in `self._current_event_fetches` to allow + concurrent `_get_events_from_cache_or_db` calls to reuse the same fetch. + """ + log_ctx = current_context() + log_ctx.record_event_fetch(len(missing_events_ids)) + + # Add entries to `self._current_event_fetches` for each event we're + # going to pull from the DB. We use a single deferred that resolves + # to all the events we pulled from the DB (this will result in this + # function returning more events than requested, but that can happen + # already due to `_get_events_from_db`). + fetching_deferred: ObservableDeferred[ + Dict[str, EventCacheEntry] + ] = ObservableDeferred(defer.Deferred(), consumeErrors=True) for event_id in missing_events_ids: - self._current_event_fetches.pop(event_id, None) + self._current_event_fetches[event_id] = fetching_deferred - with PreserveLoggingContext(): - fetching_deferred.callback(missing_events) + # Note that _get_events_from_db is also responsible for turning db rows + # into FrozenEvents (via _get_event_from_row), which involves seeing if + # the events have been redacted, and if so pulling the redaction event + # out of the database to check it. + # + try: + missing_events = await self._get_events_from_db( + missing_events_ids, + ) + except Exception as e: + with PreserveLoggingContext(): + fetching_deferred.errback(e) + raise e + finally: + # Ensure that we mark these events as no longer being fetched. + for event_id in missing_events_ids: + self._current_event_fetches.pop(event_id, None) + + with PreserveLoggingContext(): + fetching_deferred.callback(missing_events) + + return missing_events + + # We must allow the database fetch to complete in the presence of + # cancellations, since multiple `_get_events_from_cache_or_db` calls can + # reuse the same fetch. + missing_events: Dict[str, EventCacheEntry] = await delay_cancellation( + get_missing_events_from_db() + ) + event_entry_map.update(missing_events) if already_fetching_deferreds: # Wait for the other event requests to finish and add their results |