From 8e9739449dd6d3c133adf9e995d27d06518a0bcf Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 25 Apr 2023 13:30:41 -0400 Subject: Add unstable /keys/claim endpoint which always returns fallback keys. (#15462) It can be useful to always return the fallback key when attempting to claim keys. This adds an unstable endpoint for `/keys/claim` which always returns fallback keys in addition to one-time-keys. The fallback key(s) are not marked as "used" unless there are no corresponding OTKs. This is currently defined in MSC3983 (although likely to be split out to a separate MSC). The endpoint shape may change or be requested differently (i.e. a keyword parameter on the current endpoint), but the core logic should be reasonable. --- tests/handlers/test_e2e_keys.py | 241 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 235 insertions(+), 6 deletions(-) (limited to 'tests/handlers') diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 013b9ee550..18edebd652 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -160,7 +160,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): res2 = self.get_success( self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + {"one_time_keys": {local_user: {device_id: "alg1"}}}, + timeout=None, + always_include_fallback_keys=False, ) ) self.assertEqual( @@ -203,7 +205,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): # key claim_res = self.get_success( self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + {"one_time_keys": {local_user: {device_id: "alg1"}}}, + timeout=None, + always_include_fallback_keys=False, ) ) self.assertEqual( @@ -220,7 +224,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): # claiming an OTK again should return the same fallback key claim_res = self.get_success( self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + {"one_time_keys": {local_user: {device_id: "alg1"}}}, + timeout=None, + always_include_fallback_keys=False, ) ) self.assertEqual( @@ -267,7 +273,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): claim_res = self.get_success( self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + {"one_time_keys": {local_user: {device_id: "alg1"}}}, + timeout=None, + always_include_fallback_keys=False, ) ) self.assertEqual( @@ -277,7 +285,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): claim_res = self.get_success( self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + {"one_time_keys": {local_user: {device_id: "alg1"}}}, + timeout=None, + always_include_fallback_keys=False, ) ) self.assertEqual( @@ -296,7 +306,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): claim_res = self.get_success( self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + {"one_time_keys": {local_user: {device_id: "alg1"}}}, + timeout=None, + always_include_fallback_keys=False, ) ) self.assertEqual( @@ -304,6 +316,75 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}}, ) + def test_fallback_key_always_returned(self) -> None: + local_user = "@boris:" + self.hs.hostname + device_id = "xyz" + fallback_key = {"alg1:k1": "fallback_key1"} + otk = {"alg1:k2": "key2"} + + # we shouldn't have any unused fallback keys yet + res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id) + ) + self.assertEqual(res, []) + + # Upload a OTK & fallback key. + self.get_success( + self.handler.upload_keys_for_user( + local_user, + device_id, + {"one_time_keys": otk, "fallback_keys": fallback_key}, + ) + ) + + # we should now have an unused alg1 key + fallback_res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id) + ) + self.assertEqual(fallback_res, ["alg1"]) + + # Claiming an OTK and requesting to always return the fallback key should + # return both. + claim_res = self.get_success( + self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id: "alg1"}}}, + timeout=None, + always_include_fallback_keys=True, + ) + ) + self.assertEqual( + claim_res, + { + "failures": {}, + "one_time_keys": {local_user: {device_id: {**fallback_key, **otk}}}, + }, + ) + + # This should not mark the key as used. + fallback_res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id) + ) + self.assertEqual(fallback_res, ["alg1"]) + + # Claiming an OTK again should return only the fallback key. + claim_res = self.get_success( + self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id: "alg1"}}}, + timeout=None, + always_include_fallback_keys=True, + ) + ) + self.assertEqual( + claim_res, + {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}}, + ) + + # And mark it as used. + fallback_res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id) + ) + self.assertEqual(fallback_res, []) + def test_replace_master_key(self) -> None: """uploading a new signing key should make the old signing key unavailable""" local_user = "@boris:" + self.hs.hostname @@ -1004,6 +1085,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): } }, timeout=None, + always_include_fallback_keys=False, ) ) self.assertEqual( @@ -1016,6 +1098,153 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): }, ) + @override_config({"experimental_features": {"msc3983_appservice_otk_claims": True}}) + def test_query_appservice_with_fallback(self) -> None: + local_user = "@boris:" + self.hs.hostname + device_id_1 = "xyz" + fallback_key = {"alg1:k1": {"desc": "fallback_key1", "fallback": True}} + otk = {"alg1:k2": {"desc": "key2"}} + as_fallback_key = {"alg1:k3": {"desc": "fallback_key3", "fallback": True}} + as_otk = {"alg1:k4": {"desc": "key4"}} + + # Inject an appservice interested in this user. + appservice = ApplicationService( + token="i_am_an_app_service", + id="1234", + namespaces={"users": [{"regex": r"@boris:.+", "exclusive": True}]}, + # Note: this user does not have to match the regex above + sender="@as_main:test", + ) + self.hs.get_datastores().main.services_cache = [appservice] + self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex( + [appservice] + ) + + # Setup a response. + self.appservice_api.claim_client_keys.return_value = make_awaitable( + ({local_user: {device_id_1: {**as_otk, **as_fallback_key}}}, []) + ) + + # Claim OTKs, which will ask the appservice and do nothing else. + claim_res = self.get_success( + self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id_1: "alg1"}}}, + timeout=None, + always_include_fallback_keys=True, + ) + ) + self.assertEqual( + claim_res, + { + "failures": {}, + "one_time_keys": { + local_user: {device_id_1: {**as_otk, **as_fallback_key}} + }, + }, + ) + + # Now upload a fallback key. + res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1) + ) + self.assertEqual(res, []) + + self.get_success( + self.handler.upload_keys_for_user( + local_user, + device_id_1, + {"fallback_keys": fallback_key}, + ) + ) + + # we should now have an unused alg1 key + fallback_res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1) + ) + self.assertEqual(fallback_res, ["alg1"]) + + # The appservice will return only the OTK. + self.appservice_api.claim_client_keys.return_value = make_awaitable( + ({local_user: {device_id_1: as_otk}}, []) + ) + + # Claim OTKs, which should return the OTK from the appservice and the + # uploaded fallback key. + claim_res = self.get_success( + self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id_1: "alg1"}}}, + timeout=None, + always_include_fallback_keys=True, + ) + ) + self.assertEqual( + claim_res, + { + "failures": {}, + "one_time_keys": { + local_user: {device_id_1: {**as_otk, **fallback_key}} + }, + }, + ) + + # But the fallback key should not be marked as used. + fallback_res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1) + ) + self.assertEqual(fallback_res, ["alg1"]) + + # Now upload a OTK. + self.get_success( + self.handler.upload_keys_for_user( + local_user, + device_id_1, + {"one_time_keys": otk}, + ) + ) + + # Claim OTKs, which will return information only from the database. + claim_res = self.get_success( + self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id_1: "alg1"}}}, + timeout=None, + always_include_fallback_keys=True, + ) + ) + self.assertEqual( + claim_res, + { + "failures": {}, + "one_time_keys": {local_user: {device_id_1: {**otk, **fallback_key}}}, + }, + ) + + # But the fallback key should not be marked as used. + fallback_res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1) + ) + self.assertEqual(fallback_res, ["alg1"]) + + # Finally, return only the fallback key from the appservice. + self.appservice_api.claim_client_keys.return_value = make_awaitable( + ({local_user: {device_id_1: as_fallback_key}}, []) + ) + + # Claim OTKs, which will return only the fallback key from the database. + claim_res = self.get_success( + self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id_1: "alg1"}}}, + timeout=None, + always_include_fallback_keys=True, + ) + ) + self.assertEqual( + claim_res, + { + "failures": {}, + "one_time_keys": {local_user: {device_id_1: as_fallback_key}}, + }, + ) + @override_config({"experimental_features": {"msc3984_appservice_key_query": True}}) def test_query_local_devices_appservice(self) -> None: """Test that querying of appservices for keys overrides responses from the database.""" -- cgit 1.4.1