summary refs log tree commit diff
path: root/tests/handlers/test_federation.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers/test_federation.py')
-rw-r--r--tests/handlers/test_federation.py58
1 files changed, 29 insertions, 29 deletions
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 57675fa407..bf0862ed54 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -371,14 +371,14 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
         # We mock out the FederationClient.backfill method, to pretend that a remote
         # server has returned our fake event.
         federation_client_backfill_mock = Mock(return_value=make_awaitable([event]))
-        self.hs.get_federation_client().backfill = federation_client_backfill_mock
+        self.hs.get_federation_client().backfill = federation_client_backfill_mock  # type: ignore[assignment]
 
         # We also mock the persist method with a side effect of itself. This allows us
         # to track when it has been called while preserving its function.
         persist_events_and_notify_mock = Mock(
             side_effect=self.hs.get_federation_event_handler().persist_events_and_notify
         )
-        self.hs.get_federation_event_handler().persist_events_and_notify = (
+        self.hs.get_federation_event_handler().persist_events_and_notify = (  # type: ignore[assignment]
             persist_events_and_notify_mock
         )
 
@@ -575,26 +575,6 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
         fed_client = fed_handler.federation_client
 
         room_id = "!room:example.com"
-        membership_event = make_event_from_dict(
-            {
-                "room_id": room_id,
-                "type": "m.room.member",
-                "sender": "@alice:test",
-                "state_key": "@alice:test",
-                "content": {"membership": "join"},
-            },
-            RoomVersions.V10,
-        )
-
-        mock_make_membership_event = Mock(
-            return_value=make_awaitable(
-                (
-                    "example.com",
-                    membership_event,
-                    RoomVersions.V10,
-                )
-            )
-        )
 
         EVENT_CREATE = make_event_from_dict(
             {
@@ -640,6 +620,26 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
             },
             room_version=RoomVersions.V10,
         )
+        membership_event = make_event_from_dict(
+            {
+                "room_id": room_id,
+                "type": "m.room.member",
+                "sender": "@alice:test",
+                "state_key": "@alice:test",
+                "content": {"membership": "join"},
+                "prev_events": [EVENT_INVITATION_MEMBERSHIP.event_id],
+            },
+            RoomVersions.V10,
+        )
+        mock_make_membership_event = Mock(
+            return_value=make_awaitable(
+                (
+                    "example.com",
+                    membership_event,
+                    RoomVersions.V10,
+                )
+            )
+        )
         mock_send_join = Mock(
             return_value=make_awaitable(
                 SendJoinResult(
@@ -712,12 +712,12 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
             fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
         ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
             # Start the partial state sync.
-            fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+            fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
             self.assertEqual(mock_sync_partial_state_room.call_count, 1)
 
             # Try to start another partial state sync.
             # Nothing should happen.
-            fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+            fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
             self.assertEqual(mock_sync_partial_state_room.call_count, 1)
 
             # End the partial state sync
@@ -729,7 +729,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
 
             # The next attempt to start the partial state sync should work.
             is_partial_state = True
-            fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+            fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
             self.assertEqual(mock_sync_partial_state_room.call_count, 2)
 
     def test_partial_state_room_sync_restart(self) -> None:
@@ -764,7 +764,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
             fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
         ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
             # Start the partial state sync.
-            fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+            fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
             self.assertEqual(mock_sync_partial_state_room.call_count, 1)
 
             # Fail the partial state sync.
@@ -773,11 +773,11 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
             self.assertEqual(mock_sync_partial_state_room.call_count, 1)
 
             # Start the partial state sync again.
-            fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+            fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
             self.assertEqual(mock_sync_partial_state_room.call_count, 2)
 
             # Deduplicate another partial state sync.
-            fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+            fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
             self.assertEqual(mock_sync_partial_state_room.call_count, 2)
 
             # Fail the partial state sync.
@@ -786,6 +786,6 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
             self.assertEqual(mock_sync_partial_state_room.call_count, 3)
             mock_sync_partial_state_room.assert_called_with(
                 initial_destination="hs3",
-                other_destinations=["hs2"],
+                other_destinations={"hs2"},
                 room_id="room_id",
             )