summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12529.misc1
-rw-r--r--synapse/storage/databases/main/events_worker.py83
-rw-r--r--tests/storage/databases/main/test_events_worker.py121
3 files changed, 169 insertions, 36 deletions
diff --git a/changelog.d/12529.misc b/changelog.d/12529.misc
new file mode 100644
index 0000000000..5427108742
--- /dev/null
+++ b/changelog.d/12529.misc
@@ -0,0 +1 @@
+Handle cancellation in `EventsWorkerStore._get_events_from_cache_or_db`.
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 6d6e146ff1..c31fc00eaa 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -75,7 +75,7 @@ from synapse.storage.util.id_generators import (
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import JsonDict, get_domain_from_id
 from synapse.util import unwrapFirstError
-from synapse.util.async_helpers import ObservableDeferred
+from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.iterutils import batch_iter
@@ -640,42 +640,57 @@ class EventsWorkerStore(SQLBaseStore):
         missing_events_ids.difference_update(already_fetching_ids)
 
         if missing_events_ids:
-            log_ctx = current_context()
-            log_ctx.record_event_fetch(len(missing_events_ids))
-
-            # Add entries to `self._current_event_fetches` for each event we're
-            # going to pull from the DB. We use a single deferred that resolves
-            # to all the events we pulled from the DB (this will result in this
-            # function returning more events than requested, but that can happen
-            # already due to `_get_events_from_db`).
-            fetching_deferred: ObservableDeferred[
-                Dict[str, EventCacheEntry]
-            ] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
-            for event_id in missing_events_ids:
-                self._current_event_fetches[event_id] = fetching_deferred
-
-            # Note that _get_events_from_db is also responsible for turning db rows
-            # into FrozenEvents (via _get_event_from_row), which involves seeing if
-            # the events have been redacted, and if so pulling the redaction event out
-            # of the database to check it.
-            #
-            try:
-                missing_events = await self._get_events_from_db(
-                    missing_events_ids,
-                )
 
-                event_entry_map.update(missing_events)
-            except Exception as e:
-                with PreserveLoggingContext():
-                    fetching_deferred.errback(e)
-                raise e
-            finally:
-                # Ensure that we mark these events as no longer being fetched.
+            async def get_missing_events_from_db() -> Dict[str, EventCacheEntry]:
+                """Fetches the events in `missing_event_ids` from the database.
+
+                Also creates entries in `self._current_event_fetches` to allow
+                concurrent `_get_events_from_cache_or_db` calls to reuse the same fetch.
+                """
+                log_ctx = current_context()
+                log_ctx.record_event_fetch(len(missing_events_ids))
+
+                # Add entries to `self._current_event_fetches` for each event we're
+                # going to pull from the DB. We use a single deferred that resolves
+                # to all the events we pulled from the DB (this will result in this
+                # function returning more events than requested, but that can happen
+                # already due to `_get_events_from_db`).
+                fetching_deferred: ObservableDeferred[
+                    Dict[str, EventCacheEntry]
+                ] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
                 for event_id in missing_events_ids:
-                    self._current_event_fetches.pop(event_id, None)
+                    self._current_event_fetches[event_id] = fetching_deferred
 
-            with PreserveLoggingContext():
-                fetching_deferred.callback(missing_events)
+                # Note that _get_events_from_db is also responsible for turning db rows
+                # into FrozenEvents (via _get_event_from_row), which involves seeing if
+                # the events have been redacted, and if so pulling the redaction event
+                # out of the database to check it.
+                #
+                try:
+                    missing_events = await self._get_events_from_db(
+                        missing_events_ids,
+                    )
+                except Exception as e:
+                    with PreserveLoggingContext():
+                        fetching_deferred.errback(e)
+                    raise e
+                finally:
+                    # Ensure that we mark these events as no longer being fetched.
+                    for event_id in missing_events_ids:
+                        self._current_event_fetches.pop(event_id, None)
+
+                with PreserveLoggingContext():
+                    fetching_deferred.callback(missing_events)
+
+                return missing_events
+
+            # We must allow the database fetch to complete in the presence of
+            # cancellations, since multiple `_get_events_from_cache_or_db` calls can
+            # reuse the same fetch.
+            missing_events: Dict[str, EventCacheEntry] = await delay_cancellation(
+                get_missing_events_from_db()
+            )
+            event_entry_map.update(missing_events)
 
         if already_fetching_deferreds:
             # Wait for the other event requests to finish and add their results
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 1f6a9eb07b..bf6374f93d 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -13,10 +13,11 @@
 # limitations under the License.
 import json
 from contextlib import contextmanager
-from typing import Generator
+from typing import Generator, Tuple
+from unittest import mock
 
 from twisted.enterprise.adbapi import ConnectionPool
-from twisted.internet.defer import ensureDeferred
+from twisted.internet.defer import CancelledError, Deferred, ensureDeferred
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.room_versions import EventFormatVersions, RoomVersions
@@ -281,3 +282,119 @@ class DatabaseOutageTestCase(unittest.HomeserverTestCase):
 
         # This next event fetch should succeed
         self.get_success(self.store.get_event(self.event_ids[0]))
+
+
+class GetEventCancellationTestCase(unittest.HomeserverTestCase):
+    """Test cancellation of `get_event` calls."""
+
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+        self.store: EventsWorkerStore = hs.get_datastores().main
+
+        self.user = self.register_user("user", "pass")
+        self.token = self.login(self.user, "pass")
+
+        self.room = self.helper.create_room_as(self.user, tok=self.token)
+
+        res = self.helper.send(self.room, tok=self.token)
+        self.event_id = res["event_id"]
+
+        # Reset the event cache so the tests start with it empty
+        self.store._get_event_cache.clear()
+
+    @contextmanager
+    def blocking_get_event_calls(
+        self,
+    ) -> Generator[
+        Tuple["Deferred[None]", "Deferred[None]", "Deferred[None]"], None, None
+    ]:
+        """Starts two concurrent `get_event` calls for the same event.
+
+        Both `get_event` calls will use the same database fetch, which will be blocked
+        at the time this function returns.
+
+        Returns:
+            A tuple containing:
+             * A `Deferred` that unblocks the database fetch.
+             * A cancellable `Deferred` for the first `get_event` call.
+             * A cancellable `Deferred` for the second `get_event` call.
+        """
+        # Patch `DatabasePool.runWithConnection` to block.
+        unblock: "Deferred[None]" = Deferred()
+        original_runWithConnection = self.store.db_pool.runWithConnection
+
+        async def runWithConnection(*args, **kwargs):
+            await unblock
+            return await original_runWithConnection(*args, **kwargs)
+
+        with mock.patch.object(
+            self.store.db_pool,
+            "runWithConnection",
+            new=runWithConnection,
+        ):
+            ctx1 = LoggingContext("get_event1")
+            ctx2 = LoggingContext("get_event2")
+
+            async def get_event(ctx: LoggingContext) -> None:
+                with ctx:
+                    await self.store.get_event(self.event_id)
+
+            get_event1 = ensureDeferred(get_event(ctx1))
+            get_event2 = ensureDeferred(get_event(ctx2))
+
+            # Both `get_event` calls ought to be blocked.
+            self.assertNoResult(get_event1)
+            self.assertNoResult(get_event2)
+
+            yield unblock, get_event1, get_event2
+
+        # Confirm that the two `get_event` calls shared the same database fetch.
+        self.assertEqual(ctx1.get_resource_usage().evt_db_fetch_count, 1)
+        self.assertEqual(ctx2.get_resource_usage().evt_db_fetch_count, 0)
+
+    def test_first_get_event_cancelled(self):
+        """Test cancellation of the first `get_event` call sharing a database fetch.
+
+        The first `get_event` call is the one which initiates the fetch. We expect the
+        fetch to complete despite the cancellation. Furthermore, the first `get_event`
+        call must not abort before the fetch is complete, otherwise the fetch will be
+        using a finished logging context.
+        """
+        with self.blocking_get_event_calls() as (unblock, get_event1, get_event2):
+            # Cancel the first `get_event` call.
+            get_event1.cancel()
+            # The first `get_event` call must not abort immediately, otherwise its
+            # logging context will be finished while it is still in use by the database
+            # fetch.
+            self.assertNoResult(get_event1)
+            # The second `get_event` call must not be cancelled.
+            self.assertNoResult(get_event2)
+
+            # Unblock the database fetch.
+            unblock.callback(None)
+            # A `CancelledError` should be raised out of the first `get_event` call.
+            exc = self.get_failure(get_event1, CancelledError).value
+            self.assertIsInstance(exc, CancelledError)
+            # The second `get_event` call should complete successfully.
+            self.get_success(get_event2)
+
+    def test_second_get_event_cancelled(self):
+        """Test cancellation of the second `get_event` call sharing a database fetch."""
+        with self.blocking_get_event_calls() as (unblock, get_event1, get_event2):
+            # Cancel the second `get_event` call.
+            get_event2.cancel()
+            # The first `get_event` call must not be cancelled.
+            self.assertNoResult(get_event1)
+            # The second `get_event` call gets cancelled immediately.
+            exc = self.get_failure(get_event2, CancelledError).value
+            self.assertIsInstance(exc, CancelledError)
+
+            # Unblock the database fetch.
+            unblock.callback(None)
+            # The first `get_event` call should complete successfully.
+            self.get_success(get_event1)