summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2023-01-20 16:06:52 +0000
committerErik Johnston <erik@matrix.org>2023-01-20 16:06:52 +0000
commitbc8136dd81b049b0ec7934d5b901e897a0740147 (patch)
tree894f71640642a5bf444d475bbb7831cc512d9b13 /tests
parentNewsfile (diff)
parentNewsfile (diff)
downloadsynapse-github/erikj/fix_wait_for_stream.tar.xz
Merge branch 'erikj/repl_notifieri' into erikj/fix_wait_for_stream github/erikj/fix_wait_for_stream erikj/fix_wait_for_stream
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_federation.py112
-rw-r--r--tests/module_api/test_api.py3
-rw-r--r--tests/replication/tcp/test_handler.py23
-rw-r--r--tests/storage/test_id_generators.py4
4 files changed, 125 insertions, 17 deletions
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py

index cedbb9fafc..c1558c40c3 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py
@@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import cast +from typing import Collection, Optional, cast from unittest import TestCase from unittest.mock import Mock, patch +from twisted.internet.defer import Deferred from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes @@ -679,3 +680,112 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase): f"Stale partial-stated room flag left over for {room_id} after a" f" failed do_invite_join!", ) + + def test_duplicate_partial_state_room_syncs(self) -> None: + """ + Tests that concurrent partial state syncs are not started for the same room. + """ + is_partial_state = True + end_sync: "Deferred[None]" = Deferred() + + async def is_partial_state_room(room_id: str) -> bool: + return is_partial_state + + async def sync_partial_state_room( + initial_destination: Optional[str], + other_destinations: Collection[str], + room_id: str, + ) -> None: + nonlocal end_sync + try: + await end_sync + finally: + end_sync = Deferred() + + mock_is_partial_state_room = Mock(side_effect=is_partial_state_room) + mock_sync_partial_state_room = Mock(side_effect=sync_partial_state_room) + + fed_handler = self.hs.get_federation_handler() + store = self.hs.get_datastores().main + + with patch.object( + fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room + ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room): + # Start the partial state sync. + fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id") + self.assertEqual(mock_sync_partial_state_room.call_count, 1) + + # Try to start another partial state sync. + # Nothing should happen. + fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id") + self.assertEqual(mock_sync_partial_state_room.call_count, 1) + + # End the partial state sync + is_partial_state = False + end_sync.callback(None) + + # The partial state sync should not be restarted. + self.assertEqual(mock_sync_partial_state_room.call_count, 1) + + # The next attempt to start the partial state sync should work. + is_partial_state = True + fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id") + self.assertEqual(mock_sync_partial_state_room.call_count, 2) + + def test_partial_state_room_sync_restart(self) -> None: + """ + Tests that partial state syncs are restarted when a second partial state sync + was deduplicated and the first partial state sync fails. + """ + is_partial_state = True + end_sync: "Deferred[None]" = Deferred() + + async def is_partial_state_room(room_id: str) -> bool: + return is_partial_state + + async def sync_partial_state_room( + initial_destination: Optional[str], + other_destinations: Collection[str], + room_id: str, + ) -> None: + nonlocal end_sync + try: + await end_sync + finally: + end_sync = Deferred() + + mock_is_partial_state_room = Mock(side_effect=is_partial_state_room) + mock_sync_partial_state_room = Mock(side_effect=sync_partial_state_room) + + fed_handler = self.hs.get_federation_handler() + store = self.hs.get_datastores().main + + with patch.object( + fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room + ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room): + # Start the partial state sync. + fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id") + self.assertEqual(mock_sync_partial_state_room.call_count, 1) + + # Fail the partial state sync. + # The partial state sync should not be restarted. + end_sync.errback(Exception("Failed to request /state_ids")) + self.assertEqual(mock_sync_partial_state_room.call_count, 1) + + # Start the partial state sync again. + fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id") + self.assertEqual(mock_sync_partial_state_room.call_count, 2) + + # Deduplicate another partial state sync. + fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id") + self.assertEqual(mock_sync_partial_state_room.call_count, 2) + + # Fail the partial state sync. + # It should restart with the latest parameters. + end_sync.errback(Exception("Failed to request /state_ids")) + self.assertEqual(mock_sync_partial_state_room.call_count, 3) + mock_sync_partial_state_room.assert_called_with( + initial_destination="hs3", + other_destinations=["hs2"], + room_id="room_id", + ) diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 9919938e80..8f88c0117d 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py
@@ -404,6 +404,9 @@ class ModuleApiTestCase(HomeserverTestCase): self.module_api.send_local_online_presence_to([remote_user_id]) ) + # We don't always send out federation immediately, so we advance the clock. + self.reactor.advance(1000) + # Check that a presence update was sent as part of a federation transaction found_update = False calls = ( diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py
index 555922409d..6e4055cc21 100644 --- a/tests/replication/tcp/test_handler.py +++ b/tests/replication/tcp/test_handler.py
@@ -14,7 +14,7 @@ from twisted.internet import defer -from synapse.replication.tcp.commands import PositionCommand, RdataCommand +from synapse.replication.tcp.commands import PositionCommand from tests.replication._base import BaseMultiWorkerStreamTestCase @@ -111,20 +111,14 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase): next_token = self.get_success(ctx.__aenter__()) self.get_success(ctx.__aexit__(None, None, None)) - cmd_handler.send_command( - RdataCommand("caches", "worker1", next_token, ("func_name", [], 0)) - ) - self.replicate() - self.get_success( data_handler.wait_for_stream_position("worker1", "caches", next_token) ) - # `wait_for_stream_position` should only return once master receives an - # RDATA from the worker - ctx = cache_id_gen.get_next() - next_token = self.get_success(ctx.__aenter__()) - self.get_success(ctx.__aexit__(None, None, None)) + # `wait_for_stream_position` should only return once master receives a + # notification that `next_token` has persisted. + ctx_worker1 = cache_id_gen.get_next() + next_token = self.get_success(ctx_worker1.__aenter__()) d = defer.ensureDeferred( data_handler.wait_for_stream_position("worker1", "caches", next_token) @@ -142,10 +136,7 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase): ) self.assertFalse(d.called) - # ... but receiving the RDATA should - cmd_handler.send_command( - RdataCommand("caches", "worker1", next_token, ("func_name", [], 0)) - ) - self.replicate() + # ... but worker1 finishing (and so sending an update) should. + self.get_success(ctx_worker1.__aexit__(None, None, None)) self.assertTrue(d.called) diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index ff9691c518..9174fb0964 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py
@@ -52,6 +52,7 @@ class StreamIdGeneratorTestCase(HomeserverTestCase): def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator: return StreamIdGenerator( db_conn=conn, + notifier=self.hs.get_replication_notifier(), table="foobar", column="stream_id", ) @@ -196,6 +197,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): return MultiWriterIdGenerator( conn, self.db_pool, + notifier=self.hs.get_replication_notifier(), stream_name="test_stream", instance_name=instance_name, tables=[("foobar", "instance_name", "stream_id")], @@ -630,6 +632,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): return MultiWriterIdGenerator( conn, self.db_pool, + notifier=self.hs.get_replication_notifier(), stream_name="test_stream", instance_name=instance_name, tables=[("foobar", "instance_name", "stream_id")], @@ -766,6 +769,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): return MultiWriterIdGenerator( conn, self.db_pool, + notifier=self.hs.get_replication_notifier(), stream_name="test_stream", instance_name=instance_name, tables=[