diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 8a3dfdcf75..70fc4263e7 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -19,6 +19,7 @@
# [This file includes modifications made by New Vector Limited]
#
#
+import time
from typing import Dict, Iterable
from unittest import mock
@@ -151,18 +152,30 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
def test_claim_one_time_key(self) -> None:
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
- keys = {"alg1:k1": "key1"}
-
res = self.get_success(
self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": keys}
+ local_user, device_id, {"one_time_keys": {"alg1:k1": "key1"}}
)
)
self.assertDictEqual(
res, {"one_time_key_counts": {"alg1": 1, "signed_curve25519": 0}}
)
- res2 = self.get_success(
+ # Keys should be returned in the order they were uploaded. To test, advance time
+ # a little, then upload a second key with an earlier key ID; it should get
+ # returned second.
+ self.reactor.advance(1)
+ res = self.get_success(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": {"alg1:k0": "key0"}}
+ )
+ )
+ self.assertDictEqual(
+ res, {"one_time_key_counts": {"alg1": 2, "signed_curve25519": 0}}
+ )
+
+ # now claim both keys back. They should be in the same order
+ res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
self.requester,
@@ -171,12 +184,27 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
)
self.assertEqual(
- res2,
+ res,
{
"failures": {},
"one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}},
},
)
+ res = self.get_success(
+ self.handler.claim_one_time_keys(
+ {local_user: {device_id: {"alg1": 1}}},
+ self.requester,
+ timeout=None,
+ always_include_fallback_keys=False,
+ )
+ )
+ self.assertEqual(
+ res,
+ {
+ "failures": {},
+ "one_time_keys": {local_user: {device_id: {"alg1:k0": "key0"}}},
+ },
+ )
def test_claim_one_time_key_bulk(self) -> None:
"""Like test_claim_one_time_key but claims multiple keys in one handler call."""
@@ -336,6 +364,47 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
counts_by_alg, expected_counts_by_alg, f"{user_id}:{device_id}"
)
+ def test_claim_one_time_key_bulk_ordering(self) -> None:
+ """Keys returned by the bulk claim call should be returned in the correct order"""
+
+ # Alice has lots of keys, uploaded in a specific order
+ alice = f"@alice:{self.hs.hostname}"
+ alice_dev = "alice_dev_1"
+
+ self.get_success(
+ self.handler.upload_keys_for_user(
+ alice,
+ alice_dev,
+ {"one_time_keys": {"alg1:k20": 20, "alg1:k21": 21, "alg1:k22": 22}},
+ )
+ )
+ # Advance time by 1s, to ensure that there is a difference in upload time.
+ self.reactor.advance(1)
+ self.get_success(
+ self.handler.upload_keys_for_user(
+ alice,
+ alice_dev,
+ {"one_time_keys": {"alg1:k10": 10, "alg1:k11": 11, "alg1:k12": 12}},
+ )
+ )
+
+ # Now claim some, and check we get the right ones.
+ claim_res = self.get_success(
+ self.handler.claim_one_time_keys(
+ {alice: {alice_dev: {"alg1": 2}}},
+ self.requester,
+ timeout=None,
+ always_include_fallback_keys=False,
+ )
+ )
+ # We should get the first-uploaded keys, even though they have later key ids.
+ # We should get a random set of two of k20, k21, k22.
+ self.assertEqual(claim_res["failures"], {})
+ claimed_keys = claim_res["one_time_keys"]["@alice:test"]["alice_dev_1"]
+ self.assertEqual(len(claimed_keys), 2)
+ for key_id in claimed_keys.keys():
+ self.assertIn(key_id, ["alg1:k20", "alg1:k21", "alg1:k22"])
+
def test_fallback_key(self) -> None:
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
@@ -1758,3 +1827,222 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
self.assertIs(exists, True)
self.assertIs(replaceable_without_uia, False)
+
+ def test_delete_old_one_time_keys(self) -> None:
+ """Test the db migration that clears out old OTKs"""
+
+ # We upload two sets of keys, one just over a week ago, and one just less than
+ # a week ago. Each batch contains some keys that match the deletion pattern
+ # (key IDs of 6 chars), and some that do not.
+ #
+ # Finally, set the scheduled task going, and check what gets deleted.
+
+ user_id = "@user000:" + self.hs.hostname
+ device_id = "xyz"
+
+ # The scheduled task should be for "now" in real, wallclock time, so
+ # set the test reactor to just over a week ago.
+ self.reactor.advance(time.time() - 7.5 * 24 * 3600)
+
+ # Upload some keys
+ self.get_success(
+ self.handler.upload_keys_for_user(
+ user_id,
+ device_id,
+ {
+ "one_time_keys": {
+ # some keys to delete
+ "alg1:AAAAAA": "key1",
+ "alg2:AAAAAB": {"key": "key2", "signatures": {"k1": "sig1"}},
+ # A key to *not* delete
+ "alg2:AAAAAAAAAA": {"key": "key3"},
+ }
+ },
+ )
+ )
+
+ # A day passes
+ self.reactor.advance(24 * 3600)
+
+ # Upload some more keys
+ self.get_success(
+ self.handler.upload_keys_for_user(
+ user_id,
+ device_id,
+ {
+ "one_time_keys": {
+ # some keys which match the pattern
+ "alg1:BAAAAA": "key1",
+ "alg2:BAAAAB": {"key": "key2", "signatures": {"k1": "sig1"}},
+ # A key to *not* delete
+ "alg2:BAAAAAAAAA": {"key": "key3"},
+ }
+ },
+ )
+ )
+
+ # The rest of the week passes, which should set the scheduled task going.
+ self.reactor.advance(6.5 * 24 * 3600)
+
+ # Check what we're left with in the database
+ remaining_key_ids = {
+ row[0]
+ for row in self.get_success(
+ self.handler.store.db_pool.simple_select_list(
+ "e2e_one_time_keys_json", None, ["key_id"]
+ )
+ )
+ }
+ self.assertEqual(
+ remaining_key_ids, {"AAAAAAAAAA", "BAAAAA", "BAAAAB", "BAAAAAAAAA"}
+ )
+
+ @override_config(
+ {
+ "experimental_features": {
+ "msc4263_limit_key_queries_to_users_who_share_rooms": True
+ }
+ }
+ )
+ def test_query_devices_remote_restricted_not_in_shared_room(self) -> None:
+ """Tests that querying keys for a remote user that we don't share a room
+ with returns nothing.
+ """
+
+ remote_user_id = "@test:other"
+ local_user_id = "@test:test"
+
+ # Do *not* pretend we're sharing a room with the user we're querying.
+
+ remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
+ remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
+
+ self.hs.get_federation_client().query_client_keys = mock.AsyncMock( # type: ignore[method-assign]
+ return_value={
+ "device_keys": {remote_user_id: {}},
+ "master_keys": {
+ remote_user_id: {
+ "user_id": remote_user_id,
+ "usage": ["master"],
+ "keys": {"ed25519:" + remote_master_key: remote_master_key},
+ },
+ },
+ "self_signing_keys": {
+ remote_user_id: {
+ "user_id": remote_user_id,
+ "usage": ["self_signing"],
+ "keys": {
+ "ed25519:"
+ + remote_self_signing_key: remote_self_signing_key
+ },
+ }
+ },
+ }
+ )
+
+ e2e_handler = self.hs.get_e2e_keys_handler()
+
+ query_result = self.get_success(
+ e2e_handler.query_devices(
+ {
+ "device_keys": {remote_user_id: []},
+ },
+ timeout=10,
+ from_user_id=local_user_id,
+ from_device_id="some_device_id",
+ )
+ )
+
+ self.assertEqual(
+ query_result,
+ {
+ "device_keys": {},
+ "failures": {},
+ "master_keys": {},
+ "self_signing_keys": {},
+ "user_signing_keys": {},
+ },
+ )
+
+ @override_config(
+ {
+ "experimental_features": {
+ "msc4263_limit_key_queries_to_users_who_share_rooms": True
+ }
+ }
+ )
+ def test_query_devices_remote_restricted_in_shared_room(self) -> None:
+ """Tests that querying keys for a remote user that we share a room
+ with returns the cross signing keys correctly.
+ """
+
+ remote_user_id = "@test:other"
+ local_user_id = "@test:test"
+
+ # Pretend we're sharing a room with the user we're querying. If not,
+ # `query_devices` will filter out the user ID and `_query_devices_for_destination`
+ # will return early.
+ self.store.do_users_share_a_room_joined_or_invited = mock.AsyncMock( # type: ignore[method-assign]
+ return_value=[remote_user_id]
+ )
+ self.store.get_rooms_for_user = mock.AsyncMock(return_value={"some_room_id"})
+
+ remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
+ remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
+
+ self.hs.get_federation_client().query_user_devices = mock.AsyncMock( # type: ignore[method-assign]
+ return_value={
+ "user_id": remote_user_id,
+ "stream_id": 1,
+ "devices": [],
+ "master_key": {
+ "user_id": remote_user_id,
+ "usage": ["master"],
+ "keys": {"ed25519:" + remote_master_key: remote_master_key},
+ },
+ "self_signing_key": {
+ "user_id": remote_user_id,
+ "usage": ["self_signing"],
+ "keys": {
+ "ed25519:" + remote_self_signing_key: remote_self_signing_key
+ },
+ },
+ }
+ )
+
+ e2e_handler = self.hs.get_e2e_keys_handler()
+
+ query_result = self.get_success(
+ e2e_handler.query_devices(
+ {
+ "device_keys": {remote_user_id: []},
+ },
+ timeout=10,
+ from_user_id=local_user_id,
+ from_device_id="some_device_id",
+ )
+ )
+
+ self.assertEqual(query_result["failures"], {})
+ self.assertEqual(
+ query_result["master_keys"],
+ {
+ remote_user_id: {
+ "user_id": remote_user_id,
+ "usage": ["master"],
+ "keys": {"ed25519:" + remote_master_key: remote_master_key},
+ }
+ },
+ )
+ self.assertEqual(
+ query_result["self_signing_keys"],
+ {
+ remote_user_id: {
+ "user_id": remote_user_id,
+ "usage": ["self_signing"],
+ "keys": {
+ "ed25519:" + remote_self_signing_key: remote_self_signing_key
+ },
+ }
+ },
+ )
|