summary refs log tree commit diff
path: root/tests/handlers/test_e2e_keys.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers/test_e2e_keys.py')
-rw-r--r--tests/handlers/test_e2e_keys.py241
1 files changed, 235 insertions, 6 deletions
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 013b9ee550..18edebd652 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -160,7 +160,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
 
         res2 = self.get_success(
             self.handler.claim_one_time_keys(
-                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+                {"one_time_keys": {local_user: {device_id: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=False,
             )
         )
         self.assertEqual(
@@ -203,7 +205,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         # key
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
-                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+                {"one_time_keys": {local_user: {device_id: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=False,
             )
         )
         self.assertEqual(
@@ -220,7 +224,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         # claiming an OTK again should return the same fallback key
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
-                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+                {"one_time_keys": {local_user: {device_id: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=False,
             )
         )
         self.assertEqual(
@@ -267,7 +273,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
 
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
-                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+                {"one_time_keys": {local_user: {device_id: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=False,
             )
         )
         self.assertEqual(
@@ -277,7 +285,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
 
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
-                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+                {"one_time_keys": {local_user: {device_id: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=False,
             )
         )
         self.assertEqual(
@@ -296,7 +306,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
 
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
-                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+                {"one_time_keys": {local_user: {device_id: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=False,
             )
         )
         self.assertEqual(
@@ -304,6 +316,75 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
         )
 
+    def test_fallback_key_always_returned(self) -> None:
+        local_user = "@boris:" + self.hs.hostname
+        device_id = "xyz"
+        fallback_key = {"alg1:k1": "fallback_key1"}
+        otk = {"alg1:k2": "key2"}
+
+        # we shouldn't have any unused fallback keys yet
+        res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+        )
+        self.assertEqual(res, [])
+
+        # Upload a OTK & fallback key.
+        self.get_success(
+            self.handler.upload_keys_for_user(
+                local_user,
+                device_id,
+                {"one_time_keys": otk, "fallback_keys": fallback_key},
+            )
+        )
+
+        # we should now have an unused alg1 key
+        fallback_res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+        )
+        self.assertEqual(fallback_res, ["alg1"])
+
+        # Claiming an OTK and requesting to always return the fallback key should
+        # return both.
+        claim_res = self.get_success(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=True,
+            )
+        )
+        self.assertEqual(
+            claim_res,
+            {
+                "failures": {},
+                "one_time_keys": {local_user: {device_id: {**fallback_key, **otk}}},
+            },
+        )
+
+        # This should not mark the key as used.
+        fallback_res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+        )
+        self.assertEqual(fallback_res, ["alg1"])
+
+        # Claiming an OTK again should return only the fallback key.
+        claim_res = self.get_success(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=True,
+            )
+        )
+        self.assertEqual(
+            claim_res,
+            {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
+        )
+
+        # And mark it as used.
+        fallback_res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+        )
+        self.assertEqual(fallback_res, [])
+
     def test_replace_master_key(self) -> None:
         """uploading a new signing key should make the old signing key unavailable"""
         local_user = "@boris:" + self.hs.hostname
@@ -1004,6 +1085,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
                     }
                 },
                 timeout=None,
+                always_include_fallback_keys=False,
             )
         )
         self.assertEqual(
@@ -1016,6 +1098,153 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             },
         )
 
+    @override_config({"experimental_features": {"msc3983_appservice_otk_claims": True}})
+    def test_query_appservice_with_fallback(self) -> None:
+        local_user = "@boris:" + self.hs.hostname
+        device_id_1 = "xyz"
+        fallback_key = {"alg1:k1": {"desc": "fallback_key1", "fallback": True}}
+        otk = {"alg1:k2": {"desc": "key2"}}
+        as_fallback_key = {"alg1:k3": {"desc": "fallback_key3", "fallback": True}}
+        as_otk = {"alg1:k4": {"desc": "key4"}}
+
+        # Inject an appservice interested in this user.
+        appservice = ApplicationService(
+            token="i_am_an_app_service",
+            id="1234",
+            namespaces={"users": [{"regex": r"@boris:.+", "exclusive": True}]},
+            # Note: this user does not have to match the regex above
+            sender="@as_main:test",
+        )
+        self.hs.get_datastores().main.services_cache = [appservice]
+        self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex(
+            [appservice]
+        )
+
+        # Setup a response.
+        self.appservice_api.claim_client_keys.return_value = make_awaitable(
+            ({local_user: {device_id_1: {**as_otk, **as_fallback_key}}}, [])
+        )
+
+        # Claim OTKs, which will ask the appservice and do nothing else.
+        claim_res = self.get_success(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id_1: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=True,
+            )
+        )
+        self.assertEqual(
+            claim_res,
+            {
+                "failures": {},
+                "one_time_keys": {
+                    local_user: {device_id_1: {**as_otk, **as_fallback_key}}
+                },
+            },
+        )
+
+        # Now upload a fallback key.
+        res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
+        )
+        self.assertEqual(res, [])
+
+        self.get_success(
+            self.handler.upload_keys_for_user(
+                local_user,
+                device_id_1,
+                {"fallback_keys": fallback_key},
+            )
+        )
+
+        # we should now have an unused alg1 key
+        fallback_res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
+        )
+        self.assertEqual(fallback_res, ["alg1"])
+
+        # The appservice will return only the OTK.
+        self.appservice_api.claim_client_keys.return_value = make_awaitable(
+            ({local_user: {device_id_1: as_otk}}, [])
+        )
+
+        # Claim OTKs, which should return the OTK from the appservice and the
+        # uploaded fallback key.
+        claim_res = self.get_success(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id_1: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=True,
+            )
+        )
+        self.assertEqual(
+            claim_res,
+            {
+                "failures": {},
+                "one_time_keys": {
+                    local_user: {device_id_1: {**as_otk, **fallback_key}}
+                },
+            },
+        )
+
+        # But the fallback key should not be marked as used.
+        fallback_res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
+        )
+        self.assertEqual(fallback_res, ["alg1"])
+
+        # Now upload a OTK.
+        self.get_success(
+            self.handler.upload_keys_for_user(
+                local_user,
+                device_id_1,
+                {"one_time_keys": otk},
+            )
+        )
+
+        # Claim OTKs, which will return information only from the database.
+        claim_res = self.get_success(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id_1: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=True,
+            )
+        )
+        self.assertEqual(
+            claim_res,
+            {
+                "failures": {},
+                "one_time_keys": {local_user: {device_id_1: {**otk, **fallback_key}}},
+            },
+        )
+
+        # But the fallback key should not be marked as used.
+        fallback_res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
+        )
+        self.assertEqual(fallback_res, ["alg1"])
+
+        # Finally, return only the fallback key from the appservice.
+        self.appservice_api.claim_client_keys.return_value = make_awaitable(
+            ({local_user: {device_id_1: as_fallback_key}}, [])
+        )
+
+        # Claim OTKs, which will return only the fallback key from the database.
+        claim_res = self.get_success(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id_1: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=True,
+            )
+        )
+        self.assertEqual(
+            claim_res,
+            {
+                "failures": {},
+                "one_time_keys": {local_user: {device_id_1: as_fallback_key}},
+            },
+        )
+
     @override_config({"experimental_features": {"msc3984_appservice_key_query": True}})
     def test_query_local_devices_appservice(self) -> None:
         """Test that querying of appservices for keys overrides responses from the database."""