diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 57675fa407..bf0862ed54 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -371,14 +371,14 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
# We mock out the FederationClient.backfill method, to pretend that a remote
# server has returned our fake event.
federation_client_backfill_mock = Mock(return_value=make_awaitable([event]))
- self.hs.get_federation_client().backfill = federation_client_backfill_mock
+ self.hs.get_federation_client().backfill = federation_client_backfill_mock # type: ignore[assignment]
# We also mock the persist method with a side effect of itself. This allows us
# to track when it has been called while preserving its function.
persist_events_and_notify_mock = Mock(
side_effect=self.hs.get_federation_event_handler().persist_events_and_notify
)
- self.hs.get_federation_event_handler().persist_events_and_notify = (
+ self.hs.get_federation_event_handler().persist_events_and_notify = ( # type: ignore[assignment]
persist_events_and_notify_mock
)
@@ -575,26 +575,6 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
fed_client = fed_handler.federation_client
room_id = "!room:example.com"
- membership_event = make_event_from_dict(
- {
- "room_id": room_id,
- "type": "m.room.member",
- "sender": "@alice:test",
- "state_key": "@alice:test",
- "content": {"membership": "join"},
- },
- RoomVersions.V10,
- )
-
- mock_make_membership_event = Mock(
- return_value=make_awaitable(
- (
- "example.com",
- membership_event,
- RoomVersions.V10,
- )
- )
- )
EVENT_CREATE = make_event_from_dict(
{
@@ -640,6 +620,26 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
},
room_version=RoomVersions.V10,
)
+ membership_event = make_event_from_dict(
+ {
+ "room_id": room_id,
+ "type": "m.room.member",
+ "sender": "@alice:test",
+ "state_key": "@alice:test",
+ "content": {"membership": "join"},
+ "prev_events": [EVENT_INVITATION_MEMBERSHIP.event_id],
+ },
+ RoomVersions.V10,
+ )
+ mock_make_membership_event = Mock(
+ return_value=make_awaitable(
+ (
+ "example.com",
+ membership_event,
+ RoomVersions.V10,
+ )
+ )
+ )
mock_send_join = Mock(
return_value=make_awaitable(
SendJoinResult(
@@ -712,12 +712,12 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
# Start the partial state sync.
- fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+ fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
# Try to start another partial state sync.
# Nothing should happen.
- fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+ fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
# End the partial state sync
@@ -729,7 +729,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
# The next attempt to start the partial state sync should work.
is_partial_state = True
- fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+ fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 2)
def test_partial_state_room_sync_restart(self) -> None:
@@ -764,7 +764,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
# Start the partial state sync.
- fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+ fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
# Fail the partial state sync.
@@ -773,11 +773,11 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
# Start the partial state sync again.
- fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+ fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 2)
# Deduplicate another partial state sync.
- fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+ fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 2)
# Fail the partial state sync.
@@ -786,6 +786,6 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(mock_sync_partial_state_room.call_count, 3)
mock_sync_partial_state_room.assert_called_with(
initial_destination="hs3",
- other_destinations=["hs2"],
+ other_destinations={"hs2"},
room_id="room_id",
)
|