summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/appservice/test_api.py59
-rw-r--r--tests/handlers/test_e2e_keys.py76
2 files changed, 134 insertions, 1 deletions
diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py
index 9d183b733e..0dd02b7d58 100644
--- a/tests/appservice/test_api.py
+++ b/tests/appservice/test_api.py
@@ -105,3 +105,62 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(self.request_url, URL_LOCATION)
         self.assertEqual(result, SUCCESS_RESULT_LOCATION)
+
+    def test_claim_keys(self) -> None:
+        """
+        Tests that the /keys/claim response is properly parsed for missing
+        keys.
+        """
+
+        RESPONSE: JsonDict = {
+            "@alice:example.org": {
+                "DEVICE_1": {
+                    "signed_curve25519:AAAAHg": {
+                        # We don't really care about the content of the keys,
+                        # they get passed back transparently.
+                    },
+                    "signed_curve25519:BBBBHg": {},
+                },
+                "DEVICE_2": {"signed_curve25519:CCCCHg": {}},
+            },
+        }
+
+        async def post_json_get_json(
+            uri: str,
+            post_json: Any,
+            headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]],
+        ) -> JsonDict:
+            # Ensure the access token is passed as both a header and query arg.
+            if not headers.get("Authorization"):
+                raise RuntimeError("Access token not provided")
+
+            self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"])
+            return RESPONSE
+
+        # We assign to a method, which mypy doesn't like.
+        self.api.post_json_get_json = Mock(side_effect=post_json_get_json)  # type: ignore[assignment]
+
+        MISSING_KEYS = [
+            # Known user, known device, missing algorithm.
+            ("@alice:example.org", "DEVICE_1", "signed_curve25519:DDDDHg"),
+            # Known user, missing device.
+            ("@alice:example.org", "DEVICE_3", "signed_curve25519:EEEEHg"),
+            # Unknown user.
+            ("@bob:example.org", "DEVICE_4", "signed_curve25519:FFFFHg"),
+        ]
+
+        claimed_keys, missing = self.get_success(
+            self.api.claim_client_keys(
+                self.service,
+                [
+                    # Found devices
+                    ("@alice:example.org", "DEVICE_1", "signed_curve25519:AAAAHg"),
+                    ("@alice:example.org", "DEVICE_1", "signed_curve25519:BBBBHg"),
+                    ("@alice:example.org", "DEVICE_2", "signed_curve25519:CCCCHg"),
+                ]
+                + MISSING_KEYS,
+            )
+        )
+
+        self.assertEqual(claimed_keys, RESPONSE)
+        self.assertEqual(missing, MISSING_KEYS)
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 6b4cba65d0..4ff04fc66b 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -23,18 +23,24 @@ from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import RoomEncryptionAlgorithms
 from synapse.api.errors import Codes, SynapseError
+from synapse.appservice import ApplicationService
 from synapse.handlers.device import DeviceHandler
 from synapse.server import HomeServer
+from synapse.storage.databases.main.appservice import _make_exclusive_regex
 from synapse.types import JsonDict
 from synapse.util import Clock
 
 from tests import unittest
 from tests.test_utils import make_awaitable
+from tests.unittest import override_config
 
 
 class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-        return self.setup_test_homeserver(federation_client=mock.Mock())
+        self.appservice_api = mock.Mock()
+        return self.setup_test_homeserver(
+            federation_client=mock.Mock(), application_service_api=self.appservice_api
+        )
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.handler = hs.get_e2e_keys_handler()
@@ -941,3 +947,71 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
 
             # The two requests to the local homeserver should be identical.
             self.assertEqual(response_1, response_2)
+
+    @override_config({"experimental_features": {"msc3983_appservice_otk_claims": True}})
+    def test_query_appservice(self) -> None:
+        local_user = "@boris:" + self.hs.hostname
+        device_id_1 = "xyz"
+        fallback_key = {"alg1:k1": "fallback_key1"}
+        device_id_2 = "abc"
+        otk = {"alg1:k2": "key2"}
+
+        # Inject an appservice interested in this user.
+        appservice = ApplicationService(
+            token="i_am_an_app_service",
+            id="1234",
+            namespaces={"users": [{"regex": r"@boris:*", "exclusive": True}]},
+            # Note: this user does not have to match the regex above
+            sender="@as_main:test",
+        )
+        self.hs.get_datastores().main.services_cache = [appservice]
+        self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex(
+            [appservice]
+        )
+
+        # Setup a response, but only for device 2.
+        self.appservice_api.claim_client_keys.return_value = make_awaitable(
+            ({local_user: {device_id_2: otk}}, [(local_user, device_id_1, "alg1")])
+        )
+
+        # we shouldn't have any unused fallback keys yet
+        res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
+        )
+        self.assertEqual(res, [])
+
+        self.get_success(
+            self.handler.upload_keys_for_user(
+                local_user,
+                device_id_1,
+                {"fallback_keys": fallback_key},
+            )
+        )
+
+        # we should now have an unused alg1 key
+        fallback_res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
+        )
+        self.assertEqual(fallback_res, ["alg1"])
+
+        # claiming an OTK when no OTKs are available should ask the appservice, then
+        # query the fallback keys.
+        claim_res = self.get_success(
+            self.handler.claim_one_time_keys(
+                {
+                    "one_time_keys": {
+                        local_user: {device_id_1: "alg1", device_id_2: "alg1"}
+                    }
+                },
+                timeout=None,
+            )
+        )
+        self.assertEqual(
+            claim_res,
+            {
+                "failures": {},
+                "one_time_keys": {
+                    local_user: {device_id_1: fallback_key, device_id_2: otk}
+                },
+            },
+        )