From 5282ba1e2bbff2635dc09aec45fd42a56c1a4545 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 28 Mar 2023 14:26:27 -0400 Subject: Implement MSC3983 to proxy /keys/claim queries to appservices. (#15314) Experimental support for MSC3983 is behind a configuration flag. If enabled, for users which are exclusively owned by an application service then the appservice will be queried for one-time keys *if* there are none uploaded to Synapse. --- tests/appservice/test_api.py | 59 ++++++++++++++++++++++++++++++++ tests/handlers/test_e2e_keys.py | 76 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 134 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py index 9d183b733e..0dd02b7d58 100644 --- a/tests/appservice/test_api.py +++ b/tests/appservice/test_api.py @@ -105,3 +105,62 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase): ) self.assertEqual(self.request_url, URL_LOCATION) self.assertEqual(result, SUCCESS_RESULT_LOCATION) + + def test_claim_keys(self) -> None: + """ + Tests that the /keys/claim response is properly parsed for missing + keys. + """ + + RESPONSE: JsonDict = { + "@alice:example.org": { + "DEVICE_1": { + "signed_curve25519:AAAAHg": { + # We don't really care about the content of the keys, + # they get passed back transparently. + }, + "signed_curve25519:BBBBHg": {}, + }, + "DEVICE_2": {"signed_curve25519:CCCCHg": {}}, + }, + } + + async def post_json_get_json( + uri: str, + post_json: Any, + headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]], + ) -> JsonDict: + # Ensure the access token is passed as both a header and query arg. + if not headers.get("Authorization"): + raise RuntimeError("Access token not provided") + + self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"]) + return RESPONSE + + # We assign to a method, which mypy doesn't like. + self.api.post_json_get_json = Mock(side_effect=post_json_get_json) # type: ignore[assignment] + + MISSING_KEYS = [ + # Known user, known device, missing algorithm. + ("@alice:example.org", "DEVICE_1", "signed_curve25519:DDDDHg"), + # Known user, missing device. + ("@alice:example.org", "DEVICE_3", "signed_curve25519:EEEEHg"), + # Unknown user. + ("@bob:example.org", "DEVICE_4", "signed_curve25519:FFFFHg"), + ] + + claimed_keys, missing = self.get_success( + self.api.claim_client_keys( + self.service, + [ + # Found devices + ("@alice:example.org", "DEVICE_1", "signed_curve25519:AAAAHg"), + ("@alice:example.org", "DEVICE_1", "signed_curve25519:BBBBHg"), + ("@alice:example.org", "DEVICE_2", "signed_curve25519:CCCCHg"), + ] + + MISSING_KEYS, + ) + ) + + self.assertEqual(claimed_keys, RESPONSE) + self.assertEqual(missing, MISSING_KEYS) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 6b4cba65d0..4ff04fc66b 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -23,18 +23,24 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import RoomEncryptionAlgorithms from synapse.api.errors import Codes, SynapseError +from synapse.appservice import ApplicationService from synapse.handlers.device import DeviceHandler from synapse.server import HomeServer +from synapse.storage.databases.main.appservice import _make_exclusive_regex from synapse.types import JsonDict from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable +from tests.unittest import override_config class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - return self.setup_test_homeserver(federation_client=mock.Mock()) + self.appservice_api = mock.Mock() + return self.setup_test_homeserver( + federation_client=mock.Mock(), application_service_api=self.appservice_api + ) def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.handler = hs.get_e2e_keys_handler() @@ -941,3 +947,71 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): # The two requests to the local homeserver should be identical. self.assertEqual(response_1, response_2) + + @override_config({"experimental_features": {"msc3983_appservice_otk_claims": True}}) + def test_query_appservice(self) -> None: + local_user = "@boris:" + self.hs.hostname + device_id_1 = "xyz" + fallback_key = {"alg1:k1": "fallback_key1"} + device_id_2 = "abc" + otk = {"alg1:k2": "key2"} + + # Inject an appservice interested in this user. + appservice = ApplicationService( + token="i_am_an_app_service", + id="1234", + namespaces={"users": [{"regex": r"@boris:*", "exclusive": True}]}, + # Note: this user does not have to match the regex above + sender="@as_main:test", + ) + self.hs.get_datastores().main.services_cache = [appservice] + self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex( + [appservice] + ) + + # Setup a response, but only for device 2. + self.appservice_api.claim_client_keys.return_value = make_awaitable( + ({local_user: {device_id_2: otk}}, [(local_user, device_id_1, "alg1")]) + ) + + # we shouldn't have any unused fallback keys yet + res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1) + ) + self.assertEqual(res, []) + + self.get_success( + self.handler.upload_keys_for_user( + local_user, + device_id_1, + {"fallback_keys": fallback_key}, + ) + ) + + # we should now have an unused alg1 key + fallback_res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1) + ) + self.assertEqual(fallback_res, ["alg1"]) + + # claiming an OTK when no OTKs are available should ask the appservice, then + # query the fallback keys. + claim_res = self.get_success( + self.handler.claim_one_time_keys( + { + "one_time_keys": { + local_user: {device_id_1: "alg1", device_id_2: "alg1"} + } + }, + timeout=None, + ) + ) + self.assertEqual( + claim_res, + { + "failures": {}, + "one_time_keys": { + local_user: {device_id_1: fallback_key, device_id_2: otk} + }, + }, + ) -- cgit 1.4.1