diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 24e405f429..90b4da9ad5 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -174,6 +174,164 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)
+ def test_claim_one_time_key_bulk(self) -> None:
+ """Like test_claim_one_time_key but claims multiple keys in one handler call."""
+ # Apologies to the reader. This test is a little too verbose. It is particularly
+ # tricky to make assertions neatly with all these nested dictionaries in play.
+
+ # Three users with two devices each. Each device uses two algorithms.
+ # Each algorithm is invoked with two keys.
+ alice = f"@alice:{self.hs.hostname}"
+ brian = f"@brian:{self.hs.hostname}"
+ chris = f"@chris:{self.hs.hostname}"
+ one_time_keys = {
+ alice: {
+ "alice_dev_1": {
+ "alg1:k1": {"dummy_id": 1},
+ "alg1:k2": {"dummy_id": 2},
+ "alg2:k3": {"dummy_id": 3},
+ "alg2:k4": {"dummy_id": 4},
+ },
+ "alice_dev_2": {
+ "alg1:k5": {"dummy_id": 5},
+ "alg1:k6": {"dummy_id": 6},
+ "alg2:k7": {"dummy_id": 7},
+ "alg2:k8": {"dummy_id": 8},
+ },
+ },
+ brian: {
+ "brian_dev_1": {
+ "alg1:k9": {"dummy_id": 9},
+ "alg1:k10": {"dummy_id": 10},
+ "alg2:k11": {"dummy_id": 11},
+ "alg2:k12": {"dummy_id": 12},
+ },
+ "brian_dev_2": {
+ "alg1:k13": {"dummy_id": 13},
+ "alg1:k14": {"dummy_id": 14},
+ "alg2:k15": {"dummy_id": 15},
+ "alg2:k16": {"dummy_id": 16},
+ },
+ },
+ chris: {
+ "chris_dev_1": {
+ "alg1:k17": {"dummy_id": 17},
+ "alg1:k18": {"dummy_id": 18},
+ "alg2:k19": {"dummy_id": 19},
+ "alg2:k20": {"dummy_id": 20},
+ },
+ "chris_dev_2": {
+ "alg1:k21": {"dummy_id": 21},
+ "alg1:k22": {"dummy_id": 22},
+ "alg2:k23": {"dummy_id": 23},
+ "alg2:k24": {"dummy_id": 24},
+ },
+ },
+ }
+ for user_id, devices in one_time_keys.items():
+ for device_id, keys_dict in devices.items():
+ counts = self.get_success(
+ self.handler.upload_keys_for_user(
+ user_id,
+ device_id,
+ {"one_time_keys": keys_dict},
+ )
+ )
+ # The upload should report 2 keys per algorithm.
+ expected_counts = {
+ "one_time_key_counts": {
+ # See count_e2e_one_time_keys for why this is hardcoded.
+ "signed_curve25519": 0,
+ "alg1": 2,
+ "alg2": 2,
+ },
+ }
+ self.assertEqual(counts, expected_counts)
+
+ # Claim a variety of keys.
+ # Raw format, easier to make test assertions about.
+ claims_to_make = {
+ (alice, "alice_dev_1", "alg1"): 1,
+ (alice, "alice_dev_1", "alg2"): 2,
+ (alice, "alice_dev_2", "alg2"): 1,
+ (brian, "brian_dev_1", "alg1"): 2,
+ (brian, "brian_dev_2", "alg2"): 9001,
+ (chris, "chris_dev_2", "alg2"): 1,
+ }
+ # Convert to the format the handler wants.
+ query: Dict[str, Dict[str, Dict[str, int]]] = {}
+ for (user_id, device_id, algorithm), count in claims_to_make.items():
+ query.setdefault(user_id, {}).setdefault(device_id, {})[algorithm] = count
+ claim_res = self.get_success(
+ self.handler.claim_one_time_keys(
+ query,
+ self.requester,
+ timeout=None,
+ always_include_fallback_keys=False,
+ )
+ )
+
+ # No failures, please!
+ self.assertEqual(claim_res["failures"], {})
+
+ # Check that we get exactly the (user, device, algorithm)s we asked for.
+ got_otks = claim_res["one_time_keys"]
+ claimed_user_device_algorithms = {
+ (user_id, device_id, alg_key_id.split(":")[0])
+ for user_id, devices in got_otks.items()
+ for device_id, key_dict in devices.items()
+ for alg_key_id in key_dict
+ }
+ self.assertEqual(claimed_user_device_algorithms, set(claims_to_make))
+
+ # Now check the keys we got are what we expected.
+ def assertExactlyOneOtk(
+ user_id: str, device_id: str, *alg_key_pairs: str
+ ) -> None:
+ key_dict = got_otks[user_id][device_id]
+ found = 0
+ for alg_key in alg_key_pairs:
+ if alg_key in key_dict:
+ expected_key_json = one_time_keys[user_id][device_id][alg_key]
+ self.assertEqual(key_dict[alg_key], expected_key_json)
+ found += 1
+ self.assertEqual(found, 1)
+
+ def assertAllOtks(user_id: str, device_id: str, *alg_key_pairs: str) -> None:
+ key_dict = got_otks[user_id][device_id]
+ for alg_key in alg_key_pairs:
+ expected_key_json = one_time_keys[user_id][device_id][alg_key]
+ self.assertEqual(key_dict[alg_key], expected_key_json)
+
+ # Expect a single arbitrary key to be returned.
+ assertExactlyOneOtk(alice, "alice_dev_1", "alg1:k1", "alg1:k2")
+ assertExactlyOneOtk(alice, "alice_dev_2", "alg2:k7", "alg2:k8")
+ assertExactlyOneOtk(chris, "chris_dev_2", "alg2:k23", "alg2:k24")
+
+ assertAllOtks(alice, "alice_dev_1", "alg2:k3", "alg2:k4")
+ assertAllOtks(brian, "brian_dev_1", "alg1:k9", "alg1:k10")
+ assertAllOtks(brian, "brian_dev_2", "alg2:k15", "alg2:k16")
+
+ # Now check the unused key counts.
+ for user_id, devices in one_time_keys.items():
+ for device_id in devices:
+ counts_by_alg = self.get_success(
+ self.store.count_e2e_one_time_keys(user_id, device_id)
+ )
+ # Somewhat fiddley to compute the expected count dict.
+ expected_counts_by_alg = {
+ "signed_curve25519": 0,
+ }
+ for alg in ["alg1", "alg2"]:
+ claim_count = claims_to_make.get((user_id, device_id, alg), 0)
+ remaining_count = max(0, 2 - claim_count)
+ if remaining_count > 0:
+ expected_counts_by_alg[alg] = remaining_count
+
+ self.assertEqual(
+ counts_by_alg, expected_counts_by_alg, f"{user_id}:{device_id}"
+ )
+
def test_fallback_key(self) -> None:
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
|