summary refs log tree commit diff
path: root/tests/test_federation.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_federation.py')
-rw-r--r--tests/test_federation.py31
1 files changed, 19 insertions, 12 deletions
diff --git a/tests/test_federation.py b/tests/test_federation.py
index ddb43c8c98..82dfd88b99 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -12,17 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Union
+from typing import Collection, List, Optional, Union
 from unittest.mock import Mock
 
-from twisted.internet.defer import succeed
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.errors import FederationError
-from synapse.api.room_versions import RoomVersions
+from synapse.api.room_versions import RoomVersion, RoomVersions
 from synapse.events import EventBase, make_event_from_dict
 from synapse.events.snapshot import EventContext
 from synapse.federation.federation_base import event_from_pdu_json
+from synapse.handlers.device import DeviceListUpdater
 from synapse.http.types import QueryParams
 from synapse.logging.context import LoggingContext
 from synapse.server import HomeServer
@@ -81,11 +81,15 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         ) -> None:
             pass
 
-        federation_event_handler._check_event_auth = _check_event_auth
+        federation_event_handler._check_event_auth = _check_event_auth  # type: ignore[assignment]
         self.client = self.hs.get_federation_client()
-        self.client._check_sigs_and_hash_for_pulled_events_and_fetch = (
-            lambda dest, pdus, **k: succeed(pdus)
-        )
+
+        async def _check_sigs_and_hash_for_pulled_events_and_fetch(
+            dest: str, pdus: Collection[EventBase], room_version: RoomVersion
+        ) -> List[EventBase]:
+            return list(pdus)
+
+        self.client._check_sigs_and_hash_for_pulled_events_and_fetch = _check_sigs_and_hash_for_pulled_events_and_fetch  # type: ignore[assignment]
 
         # Send the join, it should return None (which is not an error)
         self.assertEqual(
@@ -187,7 +191,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
 
         # Register the mock on the federation client.
         federation_client = self.hs.get_federation_client()
-        federation_client.query_user_devices = Mock(side_effect=query_user_devices)
+        federation_client.query_user_devices = Mock(side_effect=query_user_devices)  # type: ignore[assignment]
 
         # Register a mock on the store so that the incoming update doesn't fail because
         # we don't share a room with the user.
@@ -197,6 +201,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         # Manually inject a fake device list update. We need this update to include at
         # least one prev_id so that the user's device list will need to be retried.
         device_list_updater = self.hs.get_device_handler().device_list_updater
+        assert isinstance(device_list_updater, DeviceListUpdater)
         self.get_success(
             device_list_updater.incoming_device_list_update(
                 origin=remote_origin,
@@ -236,7 +241,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
 
         # Register mock device list retrieval on the federation client.
         federation_client = self.hs.get_federation_client()
-        federation_client.query_user_devices = Mock(
+        federation_client.query_user_devices = Mock(  # type: ignore[assignment]
             return_value=make_awaitable(
                 {
                     "user_id": remote_user_id,
@@ -269,16 +274,18 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         keys = self.get_success(
             self.store.get_e2e_cross_signing_keys_bulk(user_ids=[remote_user_id]),
         )
-        self.assertTrue(remote_user_id in keys)
+        self.assertIn(remote_user_id, keys)
+        key = keys[remote_user_id]
+        assert key is not None
 
         # Check that the master key is the one returned by the mock.
-        master_key = keys[remote_user_id]["master"]
+        master_key = key["master"]
         self.assertEqual(len(master_key["keys"]), 1)
         self.assertTrue("ed25519:" + remote_master_key in master_key["keys"].keys())
         self.assertTrue(remote_master_key in master_key["keys"].values())
 
         # Check that the self-signing key is the one returned by the mock.
-        self_signing_key = keys[remote_user_id]["self_signing"]
+        self_signing_key = key["self_signing"]
         self.assertEqual(len(self_signing_key["keys"]), 1)
         self.assertTrue(
             "ed25519:" + remote_self_signing_key in self_signing_key["keys"].keys(),