summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_federation.py112
1 files changed, 111 insertions, 1 deletions
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index cedbb9fafc..c1558c40c3 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import cast
+from typing import Collection, Optional, cast
 from unittest import TestCase
 from unittest.mock import Mock, patch
 
+from twisted.internet.defer import Deferred
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import EventTypes
@@ -679,3 +680,112 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
             f"Stale partial-stated room flag left over for {room_id} after a"
             f" failed do_invite_join!",
         )
+
+    def test_duplicate_partial_state_room_syncs(self) -> None:
+        """
+        Tests that concurrent partial state syncs are not started for the same room.
+        """
+        is_partial_state = True
+        end_sync: "Deferred[None]" = Deferred()
+
+        async def is_partial_state_room(room_id: str) -> bool:
+            return is_partial_state
+
+        async def sync_partial_state_room(
+            initial_destination: Optional[str],
+            other_destinations: Collection[str],
+            room_id: str,
+        ) -> None:
+            nonlocal end_sync
+            try:
+                await end_sync
+            finally:
+                end_sync = Deferred()
+
+        mock_is_partial_state_room = Mock(side_effect=is_partial_state_room)
+        mock_sync_partial_state_room = Mock(side_effect=sync_partial_state_room)
+
+        fed_handler = self.hs.get_federation_handler()
+        store = self.hs.get_datastores().main
+
+        with patch.object(
+            fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
+        ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
+            # Start the partial state sync.
+            fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+            self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+            # Try to start another partial state sync.
+            # Nothing should happen.
+            fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+            self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+            # End the partial state sync
+            is_partial_state = False
+            end_sync.callback(None)
+
+            # The partial state sync should not be restarted.
+            self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+            # The next attempt to start the partial state sync should work.
+            is_partial_state = True
+            fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+            self.assertEqual(mock_sync_partial_state_room.call_count, 2)
+
+    def test_partial_state_room_sync_restart(self) -> None:
+        """
+        Tests that partial state syncs are restarted when a second partial state sync
+        was deduplicated and the first partial state sync fails.
+        """
+        is_partial_state = True
+        end_sync: "Deferred[None]" = Deferred()
+
+        async def is_partial_state_room(room_id: str) -> bool:
+            return is_partial_state
+
+        async def sync_partial_state_room(
+            initial_destination: Optional[str],
+            other_destinations: Collection[str],
+            room_id: str,
+        ) -> None:
+            nonlocal end_sync
+            try:
+                await end_sync
+            finally:
+                end_sync = Deferred()
+
+        mock_is_partial_state_room = Mock(side_effect=is_partial_state_room)
+        mock_sync_partial_state_room = Mock(side_effect=sync_partial_state_room)
+
+        fed_handler = self.hs.get_federation_handler()
+        store = self.hs.get_datastores().main
+
+        with patch.object(
+            fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
+        ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
+            # Start the partial state sync.
+            fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+            self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+            # Fail the partial state sync.
+            # The partial state sync should not be restarted.
+            end_sync.errback(Exception("Failed to request /state_ids"))
+            self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+            # Start the partial state sync again.
+            fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+            self.assertEqual(mock_sync_partial_state_room.call_count, 2)
+
+            # Deduplicate another partial state sync.
+            fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+            self.assertEqual(mock_sync_partial_state_room.call_count, 2)
+
+            # Fail the partial state sync.
+            # It should restart with the latest parameters.
+            end_sync.errback(Exception("Failed to request /state_ids"))
+            self.assertEqual(mock_sync_partial_state_room.call_count, 3)
+            mock_sync_partial_state_room.assert_called_with(
+                initial_destination="hs3",
+                other_destinations=["hs2"],
+                room_id="room_id",
+            )