diff options
Diffstat (limited to 'tests/test_federation.py')
-rw-r--r-- | tests/test_federation.py | 31 |
1 files changed, 19 insertions, 12 deletions
diff --git a/tests/test_federation.py b/tests/test_federation.py index ddb43c8c98..82dfd88b99 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -12,17 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union +from typing import Collection, List, Optional, Union from unittest.mock import Mock -from twisted.internet.defer import succeed from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import FederationError -from synapse.api.room_versions import RoomVersions +from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext from synapse.federation.federation_base import event_from_pdu_json +from synapse.handlers.device import DeviceListUpdater from synapse.http.types import QueryParams from synapse.logging.context import LoggingContext from synapse.server import HomeServer @@ -81,11 +81,15 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) -> None: pass - federation_event_handler._check_event_auth = _check_event_auth + federation_event_handler._check_event_auth = _check_event_auth # type: ignore[assignment] self.client = self.hs.get_federation_client() - self.client._check_sigs_and_hash_for_pulled_events_and_fetch = ( - lambda dest, pdus, **k: succeed(pdus) - ) + + async def _check_sigs_and_hash_for_pulled_events_and_fetch( + dest: str, pdus: Collection[EventBase], room_version: RoomVersion + ) -> List[EventBase]: + return list(pdus) + + self.client._check_sigs_and_hash_for_pulled_events_and_fetch = _check_sigs_and_hash_for_pulled_events_and_fetch # type: ignore[assignment] # Send the join, it should return None (which is not an error) self.assertEqual( @@ -187,7 +191,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # Register the mock on the federation client. federation_client = self.hs.get_federation_client() - federation_client.query_user_devices = Mock(side_effect=query_user_devices) + federation_client.query_user_devices = Mock(side_effect=query_user_devices) # type: ignore[assignment] # Register a mock on the store so that the incoming update doesn't fail because # we don't share a room with the user. @@ -197,6 +201,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # Manually inject a fake device list update. We need this update to include at # least one prev_id so that the user's device list will need to be retried. device_list_updater = self.hs.get_device_handler().device_list_updater + assert isinstance(device_list_updater, DeviceListUpdater) self.get_success( device_list_updater.incoming_device_list_update( origin=remote_origin, @@ -236,7 +241,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # Register mock device list retrieval on the federation client. federation_client = self.hs.get_federation_client() - federation_client.query_user_devices = Mock( + federation_client.query_user_devices = Mock( # type: ignore[assignment] return_value=make_awaitable( { "user_id": remote_user_id, @@ -269,16 +274,18 @@ class MessageAcceptTests(unittest.HomeserverTestCase): keys = self.get_success( self.store.get_e2e_cross_signing_keys_bulk(user_ids=[remote_user_id]), ) - self.assertTrue(remote_user_id in keys) + self.assertIn(remote_user_id, keys) + key = keys[remote_user_id] + assert key is not None # Check that the master key is the one returned by the mock. - master_key = keys[remote_user_id]["master"] + master_key = key["master"] self.assertEqual(len(master_key["keys"]), 1) self.assertTrue("ed25519:" + remote_master_key in master_key["keys"].keys()) self.assertTrue(remote_master_key in master_key["keys"].values()) # Check that the self-signing key is the one returned by the mock. - self_signing_key = keys[remote_user_id]["self_signing"] + self_signing_key = key["self_signing"] self.assertEqual(len(self_signing_key["keys"]), 1) self.assertTrue( "ed25519:" + remote_self_signing_key in self_signing_key["keys"].keys(), |