summary refs log tree commit diff
path: root/tests/storage
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage')
-rw-r--r--tests/storage/databases/main/test_events_worker.py121
1 files changed, 119 insertions, 2 deletions
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 1f6a9eb07b..bf6374f93d 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -13,10 +13,11 @@
 # limitations under the License.
 import json
 from contextlib import contextmanager
-from typing import Generator
+from typing import Generator, Tuple
+from unittest import mock
 
 from twisted.enterprise.adbapi import ConnectionPool
-from twisted.internet.defer import ensureDeferred
+from twisted.internet.defer import CancelledError, Deferred, ensureDeferred
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.room_versions import EventFormatVersions, RoomVersions
@@ -281,3 +282,119 @@ class DatabaseOutageTestCase(unittest.HomeserverTestCase):
 
         # This next event fetch should succeed
         self.get_success(self.store.get_event(self.event_ids[0]))
+
+
+class GetEventCancellationTestCase(unittest.HomeserverTestCase):
+    """Test cancellation of `get_event` calls."""
+
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+        self.store: EventsWorkerStore = hs.get_datastores().main
+
+        self.user = self.register_user("user", "pass")
+        self.token = self.login(self.user, "pass")
+
+        self.room = self.helper.create_room_as(self.user, tok=self.token)
+
+        res = self.helper.send(self.room, tok=self.token)
+        self.event_id = res["event_id"]
+
+        # Reset the event cache so the tests start with it empty
+        self.store._get_event_cache.clear()
+
+    @contextmanager
+    def blocking_get_event_calls(
+        self,
+    ) -> Generator[
+        Tuple["Deferred[None]", "Deferred[None]", "Deferred[None]"], None, None
+    ]:
+        """Starts two concurrent `get_event` calls for the same event.
+
+        Both `get_event` calls will use the same database fetch, which will be blocked
+        at the time this function returns.
+
+        Returns:
+            A tuple containing:
+             * A `Deferred` that unblocks the database fetch.
+             * A cancellable `Deferred` for the first `get_event` call.
+             * A cancellable `Deferred` for the second `get_event` call.
+        """
+        # Patch `DatabasePool.runWithConnection` to block.
+        unblock: "Deferred[None]" = Deferred()
+        original_runWithConnection = self.store.db_pool.runWithConnection
+
+        async def runWithConnection(*args, **kwargs):
+            await unblock
+            return await original_runWithConnection(*args, **kwargs)
+
+        with mock.patch.object(
+            self.store.db_pool,
+            "runWithConnection",
+            new=runWithConnection,
+        ):
+            ctx1 = LoggingContext("get_event1")
+            ctx2 = LoggingContext("get_event2")
+
+            async def get_event(ctx: LoggingContext) -> None:
+                with ctx:
+                    await self.store.get_event(self.event_id)
+
+            get_event1 = ensureDeferred(get_event(ctx1))
+            get_event2 = ensureDeferred(get_event(ctx2))
+
+            # Both `get_event` calls ought to be blocked.
+            self.assertNoResult(get_event1)
+            self.assertNoResult(get_event2)
+
+            yield unblock, get_event1, get_event2
+
+        # Confirm that the two `get_event` calls shared the same database fetch.
+        self.assertEqual(ctx1.get_resource_usage().evt_db_fetch_count, 1)
+        self.assertEqual(ctx2.get_resource_usage().evt_db_fetch_count, 0)
+
+    def test_first_get_event_cancelled(self):
+        """Test cancellation of the first `get_event` call sharing a database fetch.
+
+        The first `get_event` call is the one which initiates the fetch. We expect the
+        fetch to complete despite the cancellation. Furthermore, the first `get_event`
+        call must not abort before the fetch is complete, otherwise the fetch will be
+        using a finished logging context.
+        """
+        with self.blocking_get_event_calls() as (unblock, get_event1, get_event2):
+            # Cancel the first `get_event` call.
+            get_event1.cancel()
+            # The first `get_event` call must not abort immediately, otherwise its
+            # logging context will be finished while it is still in use by the database
+            # fetch.
+            self.assertNoResult(get_event1)
+            # The second `get_event` call must not be cancelled.
+            self.assertNoResult(get_event2)
+
+            # Unblock the database fetch.
+            unblock.callback(None)
+            # A `CancelledError` should be raised out of the first `get_event` call.
+            exc = self.get_failure(get_event1, CancelledError).value
+            self.assertIsInstance(exc, CancelledError)
+            # The second `get_event` call should complete successfully.
+            self.get_success(get_event2)
+
+    def test_second_get_event_cancelled(self):
+        """Test cancellation of the second `get_event` call sharing a database fetch."""
+        with self.blocking_get_event_calls() as (unblock, get_event1, get_event2):
+            # Cancel the second `get_event` call.
+            get_event2.cancel()
+            # The first `get_event` call must not be cancelled.
+            self.assertNoResult(get_event1)
+            # The second `get_event` call gets cancelled immediately.
+            exc = self.get_failure(get_event2, CancelledError).value
+            self.assertIsInstance(exc, CancelledError)
+
+            # Unblock the database fetch.
+            unblock.callback(None)
+            # The first `get_event` call should complete successfully.
+            self.get_success(get_event1)