diff options
Diffstat (limited to 'tests/handlers/test_federation.py')
-rw-r--r-- | tests/handlers/test_federation.py | 112 |
1 files changed, 111 insertions, 1 deletions
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index cedbb9fafc..c1558c40c3 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import cast +from typing import Collection, Optional, cast from unittest import TestCase from unittest.mock import Mock, patch +from twisted.internet.defer import Deferred from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes @@ -679,3 +680,112 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase): f"Stale partial-stated room flag left over for {room_id} after a" f" failed do_invite_join!", ) + + def test_duplicate_partial_state_room_syncs(self) -> None: + """ + Tests that concurrent partial state syncs are not started for the same room. + """ + is_partial_state = True + end_sync: "Deferred[None]" = Deferred() + + async def is_partial_state_room(room_id: str) -> bool: + return is_partial_state + + async def sync_partial_state_room( + initial_destination: Optional[str], + other_destinations: Collection[str], + room_id: str, + ) -> None: + nonlocal end_sync + try: + await end_sync + finally: + end_sync = Deferred() + + mock_is_partial_state_room = Mock(side_effect=is_partial_state_room) + mock_sync_partial_state_room = Mock(side_effect=sync_partial_state_room) + + fed_handler = self.hs.get_federation_handler() + store = self.hs.get_datastores().main + + with patch.object( + fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room + ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room): + # Start the partial state sync. + fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id") + self.assertEqual(mock_sync_partial_state_room.call_count, 1) + + # Try to start another partial state sync. + # Nothing should happen. + fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id") + self.assertEqual(mock_sync_partial_state_room.call_count, 1) + + # End the partial state sync + is_partial_state = False + end_sync.callback(None) + + # The partial state sync should not be restarted. + self.assertEqual(mock_sync_partial_state_room.call_count, 1) + + # The next attempt to start the partial state sync should work. + is_partial_state = True + fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id") + self.assertEqual(mock_sync_partial_state_room.call_count, 2) + + def test_partial_state_room_sync_restart(self) -> None: + """ + Tests that partial state syncs are restarted when a second partial state sync + was deduplicated and the first partial state sync fails. + """ + is_partial_state = True + end_sync: "Deferred[None]" = Deferred() + + async def is_partial_state_room(room_id: str) -> bool: + return is_partial_state + + async def sync_partial_state_room( + initial_destination: Optional[str], + other_destinations: Collection[str], + room_id: str, + ) -> None: + nonlocal end_sync + try: + await end_sync + finally: + end_sync = Deferred() + + mock_is_partial_state_room = Mock(side_effect=is_partial_state_room) + mock_sync_partial_state_room = Mock(side_effect=sync_partial_state_room) + + fed_handler = self.hs.get_federation_handler() + store = self.hs.get_datastores().main + + with patch.object( + fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room + ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room): + # Start the partial state sync. + fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id") + self.assertEqual(mock_sync_partial_state_room.call_count, 1) + + # Fail the partial state sync. + # The partial state sync should not be restarted. + end_sync.errback(Exception("Failed to request /state_ids")) + self.assertEqual(mock_sync_partial_state_room.call_count, 1) + + # Start the partial state sync again. + fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id") + self.assertEqual(mock_sync_partial_state_room.call_count, 2) + + # Deduplicate another partial state sync. + fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id") + self.assertEqual(mock_sync_partial_state_room.call_count, 2) + + # Fail the partial state sync. + # It should restart with the latest parameters. + end_sync.errback(Exception("Failed to request /state_ids")) + self.assertEqual(mock_sync_partial_state_room.call_count, 3) + mock_sync_partial_state_room.assert_called_with( + initial_destination="hs3", + other_destinations=["hs2"], + room_id="room_id", + ) |