diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 013b9ee550..18edebd652 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -160,7 +160,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
res2 = self.get_success(
self.handler.claim_one_time_keys(
- {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+ {"one_time_keys": {local_user: {device_id: "alg1"}}},
+ timeout=None,
+ always_include_fallback_keys=False,
)
)
self.assertEqual(
@@ -203,7 +205,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# key
claim_res = self.get_success(
self.handler.claim_one_time_keys(
- {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+ {"one_time_keys": {local_user: {device_id: "alg1"}}},
+ timeout=None,
+ always_include_fallback_keys=False,
)
)
self.assertEqual(
@@ -220,7 +224,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# claiming an OTK again should return the same fallback key
claim_res = self.get_success(
self.handler.claim_one_time_keys(
- {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+ {"one_time_keys": {local_user: {device_id: "alg1"}}},
+ timeout=None,
+ always_include_fallback_keys=False,
)
)
self.assertEqual(
@@ -267,7 +273,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success(
self.handler.claim_one_time_keys(
- {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+ {"one_time_keys": {local_user: {device_id: "alg1"}}},
+ timeout=None,
+ always_include_fallback_keys=False,
)
)
self.assertEqual(
@@ -277,7 +285,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success(
self.handler.claim_one_time_keys(
- {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+ {"one_time_keys": {local_user: {device_id: "alg1"}}},
+ timeout=None,
+ always_include_fallback_keys=False,
)
)
self.assertEqual(
@@ -296,7 +306,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success(
self.handler.claim_one_time_keys(
- {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+ {"one_time_keys": {local_user: {device_id: "alg1"}}},
+ timeout=None,
+ always_include_fallback_keys=False,
)
)
self.assertEqual(
@@ -304,6 +316,75 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
)
+ def test_fallback_key_always_returned(self) -> None:
+ local_user = "@boris:" + self.hs.hostname
+ device_id = "xyz"
+ fallback_key = {"alg1:k1": "fallback_key1"}
+ otk = {"alg1:k2": "key2"}
+
+ # we shouldn't have any unused fallback keys yet
+ res = self.get_success(
+ self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+ )
+ self.assertEqual(res, [])
+
+ # Upload a OTK & fallback key.
+ self.get_success(
+ self.handler.upload_keys_for_user(
+ local_user,
+ device_id,
+ {"one_time_keys": otk, "fallback_keys": fallback_key},
+ )
+ )
+
+ # we should now have an unused alg1 key
+ fallback_res = self.get_success(
+ self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+ )
+ self.assertEqual(fallback_res, ["alg1"])
+
+ # Claiming an OTK and requesting to always return the fallback key should
+ # return both.
+ claim_res = self.get_success(
+ self.handler.claim_one_time_keys(
+ {"one_time_keys": {local_user: {device_id: "alg1"}}},
+ timeout=None,
+ always_include_fallback_keys=True,
+ )
+ )
+ self.assertEqual(
+ claim_res,
+ {
+ "failures": {},
+ "one_time_keys": {local_user: {device_id: {**fallback_key, **otk}}},
+ },
+ )
+
+ # This should not mark the key as used.
+ fallback_res = self.get_success(
+ self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+ )
+ self.assertEqual(fallback_res, ["alg1"])
+
+ # Claiming an OTK again should return only the fallback key.
+ claim_res = self.get_success(
+ self.handler.claim_one_time_keys(
+ {"one_time_keys": {local_user: {device_id: "alg1"}}},
+ timeout=None,
+ always_include_fallback_keys=True,
+ )
+ )
+ self.assertEqual(
+ claim_res,
+ {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
+ )
+
+ # And mark it as used.
+ fallback_res = self.get_success(
+ self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+ )
+ self.assertEqual(fallback_res, [])
+
def test_replace_master_key(self) -> None:
"""uploading a new signing key should make the old signing key unavailable"""
local_user = "@boris:" + self.hs.hostname
@@ -1004,6 +1085,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
}
},
timeout=None,
+ always_include_fallback_keys=False,
)
)
self.assertEqual(
@@ -1016,6 +1098,153 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)
+ @override_config({"experimental_features": {"msc3983_appservice_otk_claims": True}})
+ def test_query_appservice_with_fallback(self) -> None:
+ local_user = "@boris:" + self.hs.hostname
+ device_id_1 = "xyz"
+ fallback_key = {"alg1:k1": {"desc": "fallback_key1", "fallback": True}}
+ otk = {"alg1:k2": {"desc": "key2"}}
+ as_fallback_key = {"alg1:k3": {"desc": "fallback_key3", "fallback": True}}
+ as_otk = {"alg1:k4": {"desc": "key4"}}
+
+ # Inject an appservice interested in this user.
+ appservice = ApplicationService(
+ token="i_am_an_app_service",
+ id="1234",
+ namespaces={"users": [{"regex": r"@boris:.+", "exclusive": True}]},
+ # Note: this user does not have to match the regex above
+ sender="@as_main:test",
+ )
+ self.hs.get_datastores().main.services_cache = [appservice]
+ self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex(
+ [appservice]
+ )
+
+ # Setup a response.
+ self.appservice_api.claim_client_keys.return_value = make_awaitable(
+ ({local_user: {device_id_1: {**as_otk, **as_fallback_key}}}, [])
+ )
+
+ # Claim OTKs, which will ask the appservice and do nothing else.
+ claim_res = self.get_success(
+ self.handler.claim_one_time_keys(
+ {"one_time_keys": {local_user: {device_id_1: "alg1"}}},
+ timeout=None,
+ always_include_fallback_keys=True,
+ )
+ )
+ self.assertEqual(
+ claim_res,
+ {
+ "failures": {},
+ "one_time_keys": {
+ local_user: {device_id_1: {**as_otk, **as_fallback_key}}
+ },
+ },
+ )
+
+ # Now upload a fallback key.
+ res = self.get_success(
+ self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
+ )
+ self.assertEqual(res, [])
+
+ self.get_success(
+ self.handler.upload_keys_for_user(
+ local_user,
+ device_id_1,
+ {"fallback_keys": fallback_key},
+ )
+ )
+
+ # we should now have an unused alg1 key
+ fallback_res = self.get_success(
+ self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
+ )
+ self.assertEqual(fallback_res, ["alg1"])
+
+ # The appservice will return only the OTK.
+ self.appservice_api.claim_client_keys.return_value = make_awaitable(
+ ({local_user: {device_id_1: as_otk}}, [])
+ )
+
+ # Claim OTKs, which should return the OTK from the appservice and the
+ # uploaded fallback key.
+ claim_res = self.get_success(
+ self.handler.claim_one_time_keys(
+ {"one_time_keys": {local_user: {device_id_1: "alg1"}}},
+ timeout=None,
+ always_include_fallback_keys=True,
+ )
+ )
+ self.assertEqual(
+ claim_res,
+ {
+ "failures": {},
+ "one_time_keys": {
+ local_user: {device_id_1: {**as_otk, **fallback_key}}
+ },
+ },
+ )
+
+ # But the fallback key should not be marked as used.
+ fallback_res = self.get_success(
+ self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
+ )
+ self.assertEqual(fallback_res, ["alg1"])
+
+ # Now upload a OTK.
+ self.get_success(
+ self.handler.upload_keys_for_user(
+ local_user,
+ device_id_1,
+ {"one_time_keys": otk},
+ )
+ )
+
+ # Claim OTKs, which will return information only from the database.
+ claim_res = self.get_success(
+ self.handler.claim_one_time_keys(
+ {"one_time_keys": {local_user: {device_id_1: "alg1"}}},
+ timeout=None,
+ always_include_fallback_keys=True,
+ )
+ )
+ self.assertEqual(
+ claim_res,
+ {
+ "failures": {},
+ "one_time_keys": {local_user: {device_id_1: {**otk, **fallback_key}}},
+ },
+ )
+
+ # But the fallback key should not be marked as used.
+ fallback_res = self.get_success(
+ self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
+ )
+ self.assertEqual(fallback_res, ["alg1"])
+
+ # Finally, return only the fallback key from the appservice.
+ self.appservice_api.claim_client_keys.return_value = make_awaitable(
+ ({local_user: {device_id_1: as_fallback_key}}, [])
+ )
+
+ # Claim OTKs, which will return only the fallback key from the database.
+ claim_res = self.get_success(
+ self.handler.claim_one_time_keys(
+ {"one_time_keys": {local_user: {device_id_1: "alg1"}}},
+ timeout=None,
+ always_include_fallback_keys=True,
+ )
+ )
+ self.assertEqual(
+ claim_res,
+ {
+ "failures": {},
+ "one_time_keys": {local_user: {device_id_1: as_fallback_key}},
+ },
+ )
+
@override_config({"experimental_features": {"msc3984_appservice_key_query": True}})
def test_query_local_devices_appservice(self) -> None:
"""Test that querying of appservices for keys overrides responses from the database."""
|