diff --git a/tests/test_federation.py b/tests/test_federation.py
index ddb43c8c98..82dfd88b99 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -12,17 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional, Union
+from typing import Collection, List, Optional, Union
from unittest.mock import Mock
-from twisted.internet.defer import succeed
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import FederationError
-from synapse.api.room_versions import RoomVersions
+from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.federation.federation_base import event_from_pdu_json
+from synapse.handlers.device import DeviceListUpdater
from synapse.http.types import QueryParams
from synapse.logging.context import LoggingContext
from synapse.server import HomeServer
@@ -81,11 +81,15 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
) -> None:
pass
- federation_event_handler._check_event_auth = _check_event_auth
+ federation_event_handler._check_event_auth = _check_event_auth # type: ignore[assignment]
self.client = self.hs.get_federation_client()
- self.client._check_sigs_and_hash_for_pulled_events_and_fetch = (
- lambda dest, pdus, **k: succeed(pdus)
- )
+
+ async def _check_sigs_and_hash_for_pulled_events_and_fetch(
+ dest: str, pdus: Collection[EventBase], room_version: RoomVersion
+ ) -> List[EventBase]:
+ return list(pdus)
+
+ self.client._check_sigs_and_hash_for_pulled_events_and_fetch = _check_sigs_and_hash_for_pulled_events_and_fetch # type: ignore[assignment]
# Send the join, it should return None (which is not an error)
self.assertEqual(
@@ -187,7 +191,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Register the mock on the federation client.
federation_client = self.hs.get_federation_client()
- federation_client.query_user_devices = Mock(side_effect=query_user_devices)
+ federation_client.query_user_devices = Mock(side_effect=query_user_devices) # type: ignore[assignment]
# Register a mock on the store so that the incoming update doesn't fail because
# we don't share a room with the user.
@@ -197,6 +201,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Manually inject a fake device list update. We need this update to include at
# least one prev_id so that the user's device list will need to be retried.
device_list_updater = self.hs.get_device_handler().device_list_updater
+ assert isinstance(device_list_updater, DeviceListUpdater)
self.get_success(
device_list_updater.incoming_device_list_update(
origin=remote_origin,
@@ -236,7 +241,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Register mock device list retrieval on the federation client.
federation_client = self.hs.get_federation_client()
- federation_client.query_user_devices = Mock(
+ federation_client.query_user_devices = Mock( # type: ignore[assignment]
return_value=make_awaitable(
{
"user_id": remote_user_id,
@@ -269,16 +274,18 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
keys = self.get_success(
self.store.get_e2e_cross_signing_keys_bulk(user_ids=[remote_user_id]),
)
- self.assertTrue(remote_user_id in keys)
+ self.assertIn(remote_user_id, keys)
+ key = keys[remote_user_id]
+ assert key is not None
# Check that the master key is the one returned by the mock.
- master_key = keys[remote_user_id]["master"]
+ master_key = key["master"]
self.assertEqual(len(master_key["keys"]), 1)
self.assertTrue("ed25519:" + remote_master_key in master_key["keys"].keys())
self.assertTrue(remote_master_key in master_key["keys"].values())
# Check that the self-signing key is the one returned by the mock.
- self_signing_key = keys[remote_user_id]["self_signing"]
+ self_signing_key = key["self_signing"]
self.assertEqual(len(self_signing_key["keys"]), 1)
self.assertTrue(
"ed25519:" + remote_self_signing_key in self_signing_key["keys"].keys(),
|