summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-04-25 13:30:41 -0400
committerGitHub <noreply@github.com>2023-04-25 13:30:41 -0400
commit8e9739449dd6d3c133adf9e995d27d06518a0bcf (patch)
treefc3a5f11b23315b18ea87b0217b1dbd954c058ed /tests
parentMerge branch 'master' into develop (diff)
downloadsynapse-8e9739449dd6d3c133adf9e995d27d06518a0bcf.tar.xz
Add unstable /keys/claim endpoint which always returns fallback keys. (#15462)
It can be useful to always return the fallback key when attempting to
claim keys. This adds an unstable endpoint for `/keys/claim` which
always returns fallback keys in addition to one-time-keys.

The fallback key(s) are not marked as "used" unless there are no
corresponding OTKs.

This is currently defined in MSC3983 (although likely to be split out
to a separate MSC). The endpoint shape may change or be requested
differently (i.e. a keyword parameter on the current endpoint), but the
core logic should be reasonable.
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_e2e_keys.py241
1 files changed, 235 insertions, 6 deletions
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 013b9ee550..18edebd652 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -160,7 +160,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
 
         res2 = self.get_success(
             self.handler.claim_one_time_keys(
-                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+                {"one_time_keys": {local_user: {device_id: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=False,
             )
         )
         self.assertEqual(
@@ -203,7 +205,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         # key
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
-                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+                {"one_time_keys": {local_user: {device_id: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=False,
             )
         )
         self.assertEqual(
@@ -220,7 +224,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         # claiming an OTK again should return the same fallback key
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
-                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+                {"one_time_keys": {local_user: {device_id: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=False,
             )
         )
         self.assertEqual(
@@ -267,7 +273,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
 
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
-                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+                {"one_time_keys": {local_user: {device_id: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=False,
             )
         )
         self.assertEqual(
@@ -277,7 +285,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
 
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
-                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+                {"one_time_keys": {local_user: {device_id: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=False,
             )
         )
         self.assertEqual(
@@ -296,7 +306,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
 
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
-                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+                {"one_time_keys": {local_user: {device_id: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=False,
             )
         )
         self.assertEqual(
@@ -304,6 +316,75 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
         )
 
+    def test_fallback_key_always_returned(self) -> None:
+        local_user = "@boris:" + self.hs.hostname
+        device_id = "xyz"
+        fallback_key = {"alg1:k1": "fallback_key1"}
+        otk = {"alg1:k2": "key2"}
+
+        # we shouldn't have any unused fallback keys yet
+        res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+        )
+        self.assertEqual(res, [])
+
+        # Upload a OTK & fallback key.
+        self.get_success(
+            self.handler.upload_keys_for_user(
+                local_user,
+                device_id,
+                {"one_time_keys": otk, "fallback_keys": fallback_key},
+            )
+        )
+
+        # we should now have an unused alg1 key
+        fallback_res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+        )
+        self.assertEqual(fallback_res, ["alg1"])
+
+        # Claiming an OTK and requesting to always return the fallback key should
+        # return both.
+        claim_res = self.get_success(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=True,
+            )
+        )
+        self.assertEqual(
+            claim_res,
+            {
+                "failures": {},
+                "one_time_keys": {local_user: {device_id: {**fallback_key, **otk}}},
+            },
+        )
+
+        # This should not mark the key as used.
+        fallback_res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+        )
+        self.assertEqual(fallback_res, ["alg1"])
+
+        # Claiming an OTK again should return only the fallback key.
+        claim_res = self.get_success(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=True,
+            )
+        )
+        self.assertEqual(
+            claim_res,
+            {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
+        )
+
+        # And mark it as used.
+        fallback_res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+        )
+        self.assertEqual(fallback_res, [])
+
     def test_replace_master_key(self) -> None:
         """uploading a new signing key should make the old signing key unavailable"""
         local_user = "@boris:" + self.hs.hostname
@@ -1004,6 +1085,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
                     }
                 },
                 timeout=None,
+                always_include_fallback_keys=False,
             )
         )
         self.assertEqual(
@@ -1016,6 +1098,153 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             },
         )
 
+    @override_config({"experimental_features": {"msc3983_appservice_otk_claims": True}})
+    def test_query_appservice_with_fallback(self) -> None:
+        local_user = "@boris:" + self.hs.hostname
+        device_id_1 = "xyz"
+        fallback_key = {"alg1:k1": {"desc": "fallback_key1", "fallback": True}}
+        otk = {"alg1:k2": {"desc": "key2"}}
+        as_fallback_key = {"alg1:k3": {"desc": "fallback_key3", "fallback": True}}
+        as_otk = {"alg1:k4": {"desc": "key4"}}
+
+        # Inject an appservice interested in this user.
+        appservice = ApplicationService(
+            token="i_am_an_app_service",
+            id="1234",
+            namespaces={"users": [{"regex": r"@boris:.+", "exclusive": True}]},
+            # Note: this user does not have to match the regex above
+            sender="@as_main:test",
+        )
+        self.hs.get_datastores().main.services_cache = [appservice]
+        self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex(
+            [appservice]
+        )
+
+        # Setup a response.
+        self.appservice_api.claim_client_keys.return_value = make_awaitable(
+            ({local_user: {device_id_1: {**as_otk, **as_fallback_key}}}, [])
+        )
+
+        # Claim OTKs, which will ask the appservice and do nothing else.
+        claim_res = self.get_success(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id_1: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=True,
+            )
+        )
+        self.assertEqual(
+            claim_res,
+            {
+                "failures": {},
+                "one_time_keys": {
+                    local_user: {device_id_1: {**as_otk, **as_fallback_key}}
+                },
+            },
+        )
+
+        # Now upload a fallback key.
+        res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
+        )
+        self.assertEqual(res, [])
+
+        self.get_success(
+            self.handler.upload_keys_for_user(
+                local_user,
+                device_id_1,
+                {"fallback_keys": fallback_key},
+            )
+        )
+
+        # we should now have an unused alg1 key
+        fallback_res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
+        )
+        self.assertEqual(fallback_res, ["alg1"])
+
+        # The appservice will return only the OTK.
+        self.appservice_api.claim_client_keys.return_value = make_awaitable(
+            ({local_user: {device_id_1: as_otk}}, [])
+        )
+
+        # Claim OTKs, which should return the OTK from the appservice and the
+        # uploaded fallback key.
+        claim_res = self.get_success(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id_1: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=True,
+            )
+        )
+        self.assertEqual(
+            claim_res,
+            {
+                "failures": {},
+                "one_time_keys": {
+                    local_user: {device_id_1: {**as_otk, **fallback_key}}
+                },
+            },
+        )
+
+        # But the fallback key should not be marked as used.
+        fallback_res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
+        )
+        self.assertEqual(fallback_res, ["alg1"])
+
+        # Now upload a OTK.
+        self.get_success(
+            self.handler.upload_keys_for_user(
+                local_user,
+                device_id_1,
+                {"one_time_keys": otk},
+            )
+        )
+
+        # Claim OTKs, which will return information only from the database.
+        claim_res = self.get_success(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id_1: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=True,
+            )
+        )
+        self.assertEqual(
+            claim_res,
+            {
+                "failures": {},
+                "one_time_keys": {local_user: {device_id_1: {**otk, **fallback_key}}},
+            },
+        )
+
+        # But the fallback key should not be marked as used.
+        fallback_res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
+        )
+        self.assertEqual(fallback_res, ["alg1"])
+
+        # Finally, return only the fallback key from the appservice.
+        self.appservice_api.claim_client_keys.return_value = make_awaitable(
+            ({local_user: {device_id_1: as_fallback_key}}, [])
+        )
+
+        # Claim OTKs, which will return only the fallback key from the database.
+        claim_res = self.get_success(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id_1: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=True,
+            )
+        )
+        self.assertEqual(
+            claim_res,
+            {
+                "failures": {},
+                "one_time_keys": {local_user: {device_id_1: as_fallback_key}},
+            },
+        )
+
     @override_config({"experimental_features": {"msc3984_appservice_key_query": True}})
     def test_query_local_devices_appservice(self) -> None:
         """Test that querying of appservices for keys overrides responses from the database."""