diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/storage/databases/main/test_events_worker.py | 121 |
1 files changed, 119 insertions, 2 deletions
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 1f6a9eb07b..bf6374f93d 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -13,10 +13,11 @@ # limitations under the License. import json from contextlib import contextmanager -from typing import Generator +from typing import Generator, Tuple +from unittest import mock from twisted.enterprise.adbapi import ConnectionPool -from twisted.internet.defer import ensureDeferred +from twisted.internet.defer import CancelledError, Deferred, ensureDeferred from twisted.test.proto_helpers import MemoryReactor from synapse.api.room_versions import EventFormatVersions, RoomVersions @@ -281,3 +282,119 @@ class DatabaseOutageTestCase(unittest.HomeserverTestCase): # This next event fetch should succeed self.get_success(self.store.get_event(self.event_ids[0])) + + +class GetEventCancellationTestCase(unittest.HomeserverTestCase): + """Test cancellation of `get_event` calls.""" + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): + self.store: EventsWorkerStore = hs.get_datastores().main + + self.user = self.register_user("user", "pass") + self.token = self.login(self.user, "pass") + + self.room = self.helper.create_room_as(self.user, tok=self.token) + + res = self.helper.send(self.room, tok=self.token) + self.event_id = res["event_id"] + + # Reset the event cache so the tests start with it empty + self.store._get_event_cache.clear() + + @contextmanager + def blocking_get_event_calls( + self, + ) -> Generator[ + Tuple["Deferred[None]", "Deferred[None]", "Deferred[None]"], None, None + ]: + """Starts two concurrent `get_event` calls for the same event. + + Both `get_event` calls will use the same database fetch, which will be blocked + at the time this function returns. + + Returns: + A tuple containing: + * A `Deferred` that unblocks the database fetch. + * A cancellable `Deferred` for the first `get_event` call. + * A cancellable `Deferred` for the second `get_event` call. + """ + # Patch `DatabasePool.runWithConnection` to block. + unblock: "Deferred[None]" = Deferred() + original_runWithConnection = self.store.db_pool.runWithConnection + + async def runWithConnection(*args, **kwargs): + await unblock + return await original_runWithConnection(*args, **kwargs) + + with mock.patch.object( + self.store.db_pool, + "runWithConnection", + new=runWithConnection, + ): + ctx1 = LoggingContext("get_event1") + ctx2 = LoggingContext("get_event2") + + async def get_event(ctx: LoggingContext) -> None: + with ctx: + await self.store.get_event(self.event_id) + + get_event1 = ensureDeferred(get_event(ctx1)) + get_event2 = ensureDeferred(get_event(ctx2)) + + # Both `get_event` calls ought to be blocked. + self.assertNoResult(get_event1) + self.assertNoResult(get_event2) + + yield unblock, get_event1, get_event2 + + # Confirm that the two `get_event` calls shared the same database fetch. + self.assertEqual(ctx1.get_resource_usage().evt_db_fetch_count, 1) + self.assertEqual(ctx2.get_resource_usage().evt_db_fetch_count, 0) + + def test_first_get_event_cancelled(self): + """Test cancellation of the first `get_event` call sharing a database fetch. + + The first `get_event` call is the one which initiates the fetch. We expect the + fetch to complete despite the cancellation. Furthermore, the first `get_event` + call must not abort before the fetch is complete, otherwise the fetch will be + using a finished logging context. + """ + with self.blocking_get_event_calls() as (unblock, get_event1, get_event2): + # Cancel the first `get_event` call. + get_event1.cancel() + # The first `get_event` call must not abort immediately, otherwise its + # logging context will be finished while it is still in use by the database + # fetch. + self.assertNoResult(get_event1) + # The second `get_event` call must not be cancelled. + self.assertNoResult(get_event2) + + # Unblock the database fetch. + unblock.callback(None) + # A `CancelledError` should be raised out of the first `get_event` call. + exc = self.get_failure(get_event1, CancelledError).value + self.assertIsInstance(exc, CancelledError) + # The second `get_event` call should complete successfully. + self.get_success(get_event2) + + def test_second_get_event_cancelled(self): + """Test cancellation of the second `get_event` call sharing a database fetch.""" + with self.blocking_get_event_calls() as (unblock, get_event1, get_event2): + # Cancel the second `get_event` call. + get_event2.cancel() + # The first `get_event` call must not be cancelled. + self.assertNoResult(get_event1) + # The second `get_event` call gets cancelled immediately. + exc = self.get_failure(get_event2, CancelledError).value + self.assertIsInstance(exc, CancelledError) + + # Unblock the database fetch. + unblock.callback(None) + # The first `get_event` call should complete successfully. + self.get_success(get_event1) |