summary refs log tree commit diff
path: root/tests/handlers/test_e2e_keys.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers/test_e2e_keys.py')
-rw-r--r--tests/handlers/test_e2e_keys.py130
1 files changed, 62 insertions, 68 deletions
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 2eaffe511e..7917766a08 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Iterable
+from typing import Dict, Iterable
 from unittest import mock
 
 from parameterized import parameterized
@@ -31,13 +31,12 @@ from synapse.types import JsonDict, UserID
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 
 
 class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-        self.appservice_api = mock.Mock()
+        self.appservice_api = mock.AsyncMock()
         return self.setup_test_homeserver(
             federation_client=mock.Mock(), application_service_api=self.appservice_api
         )
@@ -801,29 +800,27 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
         remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
 
-        self.hs.get_federation_client().query_client_keys = mock.Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(
-                {
-                    "device_keys": {remote_user_id: {}},
-                    "master_keys": {
-                        remote_user_id: {
-                            "user_id": remote_user_id,
-                            "usage": ["master"],
-                            "keys": {"ed25519:" + remote_master_key: remote_master_key},
-                        },
-                    },
-                    "self_signing_keys": {
-                        remote_user_id: {
-                            "user_id": remote_user_id,
-                            "usage": ["self_signing"],
-                            "keys": {
-                                "ed25519:"
-                                + remote_self_signing_key: remote_self_signing_key
-                            },
-                        }
+        self.hs.get_federation_client().query_client_keys = mock.AsyncMock(  # type: ignore[assignment]
+            return_value={
+                "device_keys": {remote_user_id: {}},
+                "master_keys": {
+                    remote_user_id: {
+                        "user_id": remote_user_id,
+                        "usage": ["master"],
+                        "keys": {"ed25519:" + remote_master_key: remote_master_key},
                     },
-                }
-            )
+                },
+                "self_signing_keys": {
+                    remote_user_id: {
+                        "user_id": remote_user_id,
+                        "usage": ["self_signing"],
+                        "keys": {
+                            "ed25519:"
+                            + remote_self_signing_key: remote_self_signing_key
+                        },
+                    }
+                },
+            }
         )
 
         e2e_handler = self.hs.get_e2e_keys_handler()
@@ -874,34 +871,29 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
 
         # Pretend we're sharing a room with the user we're querying. If not,
         # `_query_devices_for_destination` will return early.
-        self.store.get_rooms_for_user = mock.Mock(
-            return_value=make_awaitable({"some_room_id"})
-        )
+        self.store.get_rooms_for_user = mock.AsyncMock(return_value={"some_room_id"})
 
         remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
         remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
 
-        self.hs.get_federation_client().query_user_devices = mock.Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(
-                {
+        self.hs.get_federation_client().query_user_devices = mock.AsyncMock(  # type: ignore[assignment]
+            return_value={
+                "user_id": remote_user_id,
+                "stream_id": 1,
+                "devices": [],
+                "master_key": {
                     "user_id": remote_user_id,
-                    "stream_id": 1,
-                    "devices": [],
-                    "master_key": {
-                        "user_id": remote_user_id,
-                        "usage": ["master"],
-                        "keys": {"ed25519:" + remote_master_key: remote_master_key},
-                    },
-                    "self_signing_key": {
-                        "user_id": remote_user_id,
-                        "usage": ["self_signing"],
-                        "keys": {
-                            "ed25519:"
-                            + remote_self_signing_key: remote_self_signing_key
-                        },
+                    "usage": ["master"],
+                    "keys": {"ed25519:" + remote_master_key: remote_master_key},
+                },
+                "self_signing_key": {
+                    "user_id": remote_user_id,
+                    "usage": ["self_signing"],
+                    "keys": {
+                        "ed25519:" + remote_self_signing_key: remote_self_signing_key
                     },
-                }
-            )
+                },
+            }
         )
 
         e2e_handler = self.hs.get_e2e_keys_handler()
@@ -987,20 +979,20 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         mock_get_rooms = mock.patch.object(
             self.store,
             "get_rooms_for_user",
-            new_callable=mock.MagicMock,
-            return_value=make_awaitable(["some_room_id"]),
+            new_callable=mock.AsyncMock,
+            return_value=["some_room_id"],
         )
         mock_get_users = mock.patch.object(
             self.store,
             "get_users_server_still_shares_room_with",
-            new_callable=mock.MagicMock,
-            return_value=make_awaitable({remote_user_id}),
+            new_callable=mock.AsyncMock,
+            return_value={remote_user_id},
         )
         mock_request = mock.patch.object(
             self.hs.get_federation_client(),
             "query_user_devices",
-            new_callable=mock.MagicMock,
-            return_value=make_awaitable(response_body),
+            new_callable=mock.AsyncMock,
+            return_value=response_body,
         )
 
         with mock_get_rooms, mock_get_users, mock_request as mocked_federation_request:
@@ -1060,8 +1052,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         )
 
         # Setup a response, but only for device 2.
-        self.appservice_api.claim_client_keys.return_value = make_awaitable(
-            ({local_user: {device_id_2: otk}}, [(local_user, device_id_1, "alg1", 1)])
+        self.appservice_api.claim_client_keys.return_value = (
+            {local_user: {device_id_2: otk}},
+            [(local_user, device_id_1, "alg1", 1)],
         )
 
         # we shouldn't have any unused fallback keys yet
@@ -1127,9 +1120,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         )
 
         # Setup a response.
-        self.appservice_api.claim_client_keys.return_value = make_awaitable(
-            ({local_user: {device_id_1: {**as_otk, **as_fallback_key}}}, [])
-        )
+        response: Dict[str, Dict[str, Dict[str, JsonDict]]] = {
+            local_user: {device_id_1: {**as_otk, **as_fallback_key}}
+        }
+        self.appservice_api.claim_client_keys.return_value = (response, [])
 
         # Claim OTKs, which will ask the appservice and do nothing else.
         claim_res = self.get_success(
@@ -1171,8 +1165,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         self.assertEqual(fallback_res, ["alg1"])
 
         # The appservice will return only the OTK.
-        self.appservice_api.claim_client_keys.return_value = make_awaitable(
-            ({local_user: {device_id_1: as_otk}}, [])
+        self.appservice_api.claim_client_keys.return_value = (
+            {local_user: {device_id_1: as_otk}},
+            [],
         )
 
         # Claim OTKs, which should return the OTK from the appservice and the
@@ -1234,8 +1229,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         self.assertEqual(fallback_res, ["alg1"])
 
         # Finally, return only the fallback key from the appservice.
-        self.appservice_api.claim_client_keys.return_value = make_awaitable(
-            ({local_user: {device_id_1: as_fallback_key}}, [])
+        self.appservice_api.claim_client_keys.return_value = (
+            {local_user: {device_id_1: as_fallback_key}},
+            [],
         )
 
         # Claim OTKs, which will return only the fallback key from the database.
@@ -1350,13 +1346,11 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         )
 
         # Setup a response.
-        self.appservice_api.query_keys.return_value = make_awaitable(
-            {
-                "device_keys": {
-                    local_user: {device_2: device_key_2b, device_3: device_key_3}
-                }
+        self.appservice_api.query_keys.return_value = {
+            "device_keys": {
+                local_user: {device_2: device_key_2b, device_3: device_key_3}
             }
-        )
+        }
 
         # Request all devices.
         res = self.get_success(self.handler.query_local_devices({local_user: None}))