diff --git a/changelog.d/12529.misc b/changelog.d/12529.misc
new file mode 100644
index 0000000000..5427108742
--- /dev/null
+++ b/changelog.d/12529.misc
@@ -0,0 +1 @@
+Handle cancellation in `EventsWorkerStore._get_events_from_cache_or_db`.
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 6d6e146ff1..c31fc00eaa 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -75,7 +75,7 @@ from synapse.storage.util.id_generators import (
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
from synapse.util import unwrapFirstError
-from synapse.util.async_helpers import ObservableDeferred
+from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
@@ -640,42 +640,57 @@ class EventsWorkerStore(SQLBaseStore):
missing_events_ids.difference_update(already_fetching_ids)
if missing_events_ids:
- log_ctx = current_context()
- log_ctx.record_event_fetch(len(missing_events_ids))
-
- # Add entries to `self._current_event_fetches` for each event we're
- # going to pull from the DB. We use a single deferred that resolves
- # to all the events we pulled from the DB (this will result in this
- # function returning more events than requested, but that can happen
- # already due to `_get_events_from_db`).
- fetching_deferred: ObservableDeferred[
- Dict[str, EventCacheEntry]
- ] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
- for event_id in missing_events_ids:
- self._current_event_fetches[event_id] = fetching_deferred
-
- # Note that _get_events_from_db is also responsible for turning db rows
- # into FrozenEvents (via _get_event_from_row), which involves seeing if
- # the events have been redacted, and if so pulling the redaction event out
- # of the database to check it.
- #
- try:
- missing_events = await self._get_events_from_db(
- missing_events_ids,
- )
- event_entry_map.update(missing_events)
- except Exception as e:
- with PreserveLoggingContext():
- fetching_deferred.errback(e)
- raise e
- finally:
- # Ensure that we mark these events as no longer being fetched.
+ async def get_missing_events_from_db() -> Dict[str, EventCacheEntry]:
+                """Fetches the events in `missing_events_ids` from the database.
+
+ Also creates entries in `self._current_event_fetches` to allow
+ concurrent `_get_events_from_cache_or_db` calls to reuse the same fetch.
+ """
+ log_ctx = current_context()
+ log_ctx.record_event_fetch(len(missing_events_ids))
+
+ # Add entries to `self._current_event_fetches` for each event we're
+ # going to pull from the DB. We use a single deferred that resolves
+ # to all the events we pulled from the DB (this will result in this
+ # function returning more events than requested, but that can happen
+ # already due to `_get_events_from_db`).
+ fetching_deferred: ObservableDeferred[
+ Dict[str, EventCacheEntry]
+ ] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
for event_id in missing_events_ids:
- self._current_event_fetches.pop(event_id, None)
+ self._current_event_fetches[event_id] = fetching_deferred
- with PreserveLoggingContext():
- fetching_deferred.callback(missing_events)
+ # Note that _get_events_from_db is also responsible for turning db rows
+ # into FrozenEvents (via _get_event_from_row), which involves seeing if
+ # the events have been redacted, and if so pulling the redaction event
+ # out of the database to check it.
+ #
+ try:
+ missing_events = await self._get_events_from_db(
+ missing_events_ids,
+ )
+ except Exception as e:
+ with PreserveLoggingContext():
+ fetching_deferred.errback(e)
+ raise e
+ finally:
+ # Ensure that we mark these events as no longer being fetched.
+ for event_id in missing_events_ids:
+ self._current_event_fetches.pop(event_id, None)
+
+ with PreserveLoggingContext():
+ fetching_deferred.callback(missing_events)
+
+ return missing_events
+
+ # We must allow the database fetch to complete in the presence of
+ # cancellations, since multiple `_get_events_from_cache_or_db` calls can
+ # reuse the same fetch.
+ missing_events: Dict[str, EventCacheEntry] = await delay_cancellation(
+ get_missing_events_from_db()
+ )
+ event_entry_map.update(missing_events)
if already_fetching_deferreds:
# Wait for the other event requests to finish and add their results
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 1f6a9eb07b..bf6374f93d 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -13,10 +13,11 @@
# limitations under the License.
import json
from contextlib import contextmanager
-from typing import Generator
+from typing import Generator, Tuple
+from unittest import mock
from twisted.enterprise.adbapi import ConnectionPool
-from twisted.internet.defer import ensureDeferred
+from twisted.internet.defer import CancelledError, Deferred, ensureDeferred
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.room_versions import EventFormatVersions, RoomVersions
@@ -281,3 +282,119 @@ class DatabaseOutageTestCase(unittest.HomeserverTestCase):
# This next event fetch should succeed
self.get_success(self.store.get_event(self.event_ids[0]))
+
+
+class GetEventCancellationTestCase(unittest.HomeserverTestCase):
+ """Test cancellation of `get_event` calls."""
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+ self.store: EventsWorkerStore = hs.get_datastores().main
+
+ self.user = self.register_user("user", "pass")
+ self.token = self.login(self.user, "pass")
+
+ self.room = self.helper.create_room_as(self.user, tok=self.token)
+
+ res = self.helper.send(self.room, tok=self.token)
+ self.event_id = res["event_id"]
+
+ # Reset the event cache so the tests start with it empty
+ self.store._get_event_cache.clear()
+
+ @contextmanager
+ def blocking_get_event_calls(
+ self,
+ ) -> Generator[
+ Tuple["Deferred[None]", "Deferred[None]", "Deferred[None]"], None, None
+ ]:
+ """Starts two concurrent `get_event` calls for the same event.
+
+ Both `get_event` calls will use the same database fetch, which will be blocked
+ at the time this function returns.
+
+ Returns:
+ A tuple containing:
+ * A `Deferred` that unblocks the database fetch.
+ * A cancellable `Deferred` for the first `get_event` call.
+ * A cancellable `Deferred` for the second `get_event` call.
+ """
+ # Patch `DatabasePool.runWithConnection` to block.
+ unblock: "Deferred[None]" = Deferred()
+ original_runWithConnection = self.store.db_pool.runWithConnection
+
+ async def runWithConnection(*args, **kwargs):
+ await unblock
+ return await original_runWithConnection(*args, **kwargs)
+
+ with mock.patch.object(
+ self.store.db_pool,
+ "runWithConnection",
+ new=runWithConnection,
+ ):
+ ctx1 = LoggingContext("get_event1")
+ ctx2 = LoggingContext("get_event2")
+
+ async def get_event(ctx: LoggingContext) -> None:
+ with ctx:
+ await self.store.get_event(self.event_id)
+
+ get_event1 = ensureDeferred(get_event(ctx1))
+ get_event2 = ensureDeferred(get_event(ctx2))
+
+ # Both `get_event` calls ought to be blocked.
+ self.assertNoResult(get_event1)
+ self.assertNoResult(get_event2)
+
+ yield unblock, get_event1, get_event2
+
+ # Confirm that the two `get_event` calls shared the same database fetch.
+ self.assertEqual(ctx1.get_resource_usage().evt_db_fetch_count, 1)
+ self.assertEqual(ctx2.get_resource_usage().evt_db_fetch_count, 0)
+
+ def test_first_get_event_cancelled(self):
+ """Test cancellation of the first `get_event` call sharing a database fetch.
+
+ The first `get_event` call is the one which initiates the fetch. We expect the
+ fetch to complete despite the cancellation. Furthermore, the first `get_event`
+ call must not abort before the fetch is complete, otherwise the fetch will be
+ using a finished logging context.
+ """
+ with self.blocking_get_event_calls() as (unblock, get_event1, get_event2):
+ # Cancel the first `get_event` call.
+ get_event1.cancel()
+ # The first `get_event` call must not abort immediately, otherwise its
+ # logging context will be finished while it is still in use by the database
+ # fetch.
+ self.assertNoResult(get_event1)
+ # The second `get_event` call must not be cancelled.
+ self.assertNoResult(get_event2)
+
+ # Unblock the database fetch.
+ unblock.callback(None)
+ # A `CancelledError` should be raised out of the first `get_event` call.
+ exc = self.get_failure(get_event1, CancelledError).value
+ self.assertIsInstance(exc, CancelledError)
+ # The second `get_event` call should complete successfully.
+ self.get_success(get_event2)
+
+ def test_second_get_event_cancelled(self):
+ """Test cancellation of the second `get_event` call sharing a database fetch."""
+ with self.blocking_get_event_calls() as (unblock, get_event1, get_event2):
+ # Cancel the second `get_event` call.
+ get_event2.cancel()
+ # The first `get_event` call must not be cancelled.
+ self.assertNoResult(get_event1)
+ # The second `get_event` call gets cancelled immediately.
+ exc = self.get_failure(get_event2, CancelledError).value
+ self.assertIsInstance(exc, CancelledError)
+
+ # Unblock the database fetch.
+ unblock.callback(None)
+ # The first `get_event` call should complete successfully.
+ self.get_success(get_event1)
|