diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index cedbb9fafc..c1558c40c3 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import cast
+from typing import Collection, Optional, cast
from unittest import TestCase
from unittest.mock import Mock, patch
+from twisted.internet.defer import Deferred
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes
@@ -679,3 +680,112 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
f"Stale partial-stated room flag left over for {room_id} after a"
f" failed do_invite_join!",
)
+
+ def test_duplicate_partial_state_room_syncs(self) -> None:
+ """
+ Tests that concurrent partial state syncs are not started for the same room.
+ """
+ is_partial_state = True
+ end_sync: "Deferred[None]" = Deferred()
+
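+        # Stand-in for the datastore's `is_partial_state_room`, controlled by the
+        # `is_partial_state` flag above.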
+ async def is_partial_state_room(room_id: str) -> bool:
+ return is_partial_state
+
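+        # Stand-in for the federation handler's `_sync_partial_state_room`: it blocks
+        # until the test resolves `end_sync`, then replaces `end_sync` with a fresh
+        # `Deferred` for the next sync attempt.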
+ async def sync_partial_state_room(
+ initial_destination: Optional[str],
+ other_destinations: Collection[str],
+ room_id: str,
+ ) -> None:
+ nonlocal end_sync
+ try:
+ await end_sync
+ finally:
+ end_sync = Deferred()
+
+ mock_is_partial_state_room = Mock(side_effect=is_partial_state_room)
+ mock_sync_partial_state_room = Mock(side_effect=sync_partial_state_room)
+
+ fed_handler = self.hs.get_federation_handler()
+ store = self.hs.get_datastores().main
+
+ with patch.object(
+ fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
+ ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
+ # Start the partial state sync.
+ fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+ self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+ # Try to start another partial state sync.
+ # Nothing should happen.
+ fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+ self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+            # End the partial state sync.
+ is_partial_state = False
+ end_sync.callback(None)
+
+ # The partial state sync should not be restarted.
+ self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+ # The next attempt to start the partial state sync should work.
+ is_partial_state = True
+ fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+ self.assertEqual(mock_sync_partial_state_room.call_count, 2)
+
+ def test_partial_state_room_sync_restart(self) -> None:
+ """
+ Tests that partial state syncs are restarted when a second partial state sync
+        has been deduplicated and the first partial state sync fails.
+ """
+ is_partial_state = True
+ end_sync: "Deferred[None]" = Deferred()
+
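+        # The same stand-ins as in `test_duplicate_partial_state_room_syncs` above.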
+ async def is_partial_state_room(room_id: str) -> bool:
+ return is_partial_state
+
+ async def sync_partial_state_room(
+ initial_destination: Optional[str],
+ other_destinations: Collection[str],
+ room_id: str,
+ ) -> None:
+ nonlocal end_sync
+ try:
+ await end_sync
+ finally:
+ end_sync = Deferred()
+
+ mock_is_partial_state_room = Mock(side_effect=is_partial_state_room)
+ mock_sync_partial_state_room = Mock(side_effect=sync_partial_state_room)
+
+ fed_handler = self.hs.get_federation_handler()
+ store = self.hs.get_datastores().main
+
+ with patch.object(
+ fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
+ ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
+ # Start the partial state sync.
+ fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+ self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+ # Fail the partial state sync.
+ # The partial state sync should not be restarted.
+ end_sync.errback(Exception("Failed to request /state_ids"))
+ self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+ # Start the partial state sync again.
+ fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+ self.assertEqual(mock_sync_partial_state_room.call_count, 2)
+
+ # Deduplicate another partial state sync.
+ fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+ self.assertEqual(mock_sync_partial_state_room.call_count, 2)
+
+ # Fail the partial state sync.
+ # It should restart with the latest parameters.
+ end_sync.errback(Exception("Failed to request /state_ids"))
+ self.assertEqual(mock_sync_partial_state_room.call_count, 3)
+ mock_sync_partial_state_room.assert_called_with(
+ initial_destination="hs3",
+ other_destinations=["hs2"],
+ room_id="room_id",
+ )
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 9919938e80..8f88c0117d 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -404,6 +404,9 @@ class ModuleApiTestCase(HomeserverTestCase):
self.module_api.send_local_online_presence_to([remote_user_id])
)
+ # We don't always send out federation immediately, so we advance the clock.
+ self.reactor.advance(1000)
+
# Check that a presence update was sent as part of a federation transaction
found_update = False
calls = (
diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py
index 555922409d..6e4055cc21 100644
--- a/tests/replication/tcp/test_handler.py
+++ b/tests/replication/tcp/test_handler.py
@@ -14,7 +14,7 @@
from twisted.internet import defer
-from synapse.replication.tcp.commands import PositionCommand, RdataCommand
+from synapse.replication.tcp.commands import PositionCommand
from tests.replication._base import BaseMultiWorkerStreamTestCase
@@ -111,20 +111,14 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase):
next_token = self.get_success(ctx.__aenter__())
self.get_success(ctx.__aexit__(None, None, None))
- cmd_handler.send_command(
- RdataCommand("caches", "worker1", next_token, ("func_name", [], 0))
- )
- self.replicate()
-
self.get_success(
data_handler.wait_for_stream_position("worker1", "caches", next_token)
)
- # `wait_for_stream_position` should only return once master receives an
- # RDATA from the worker
- ctx = cache_id_gen.get_next()
- next_token = self.get_success(ctx.__aenter__())
- self.get_success(ctx.__aexit__(None, None, None))
+ # `wait_for_stream_position` should only return once master receives a
+        # notification that `next_token` has been persisted.
+ ctx_worker1 = cache_id_gen.get_next()
+ next_token = self.get_success(ctx_worker1.__aenter__())
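+        # Note: `ctx_worker1` is deliberately not exited yet, so `next_token` has not
+        # been persisted and the wait below should not complete yet.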
d = defer.ensureDeferred(
data_handler.wait_for_stream_position("worker1", "caches", next_token)
@@ -142,10 +136,7 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase):
)
self.assertFalse(d.called)
- # ... but receiving the RDATA should
- cmd_handler.send_command(
- RdataCommand("caches", "worker1", next_token, ("func_name", [], 0))
- )
- self.replicate()
+ # ... but worker1 finishing (and so sending an update) should.
+ self.get_success(ctx_worker1.__aexit__(None, None, None))
self.assertTrue(d.called)
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index ff9691c518..9174fb0964 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -52,6 +52,7 @@ class StreamIdGeneratorTestCase(HomeserverTestCase):
def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator:
return StreamIdGenerator(
db_conn=conn,
+ notifier=self.hs.get_replication_notifier(),
table="foobar",
column="stream_id",
)
@@ -196,6 +197,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
return MultiWriterIdGenerator(
conn,
self.db_pool,
+ notifier=self.hs.get_replication_notifier(),
stream_name="test_stream",
instance_name=instance_name,
tables=[("foobar", "instance_name", "stream_id")],
@@ -630,6 +632,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
return MultiWriterIdGenerator(
conn,
self.db_pool,
+ notifier=self.hs.get_replication_notifier(),
stream_name="test_stream",
instance_name=instance_name,
tables=[("foobar", "instance_name", "stream_id")],
@@ -766,6 +769,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
return MultiWriterIdGenerator(
conn,
self.db_pool,
+ notifier=self.hs.get_replication_notifier(),
stream_name="test_stream",
instance_name=instance_name,
tables=[
|