diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index c5556f2844..24e405f429 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -322,6 +322,83 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
)
+ def test_fallback_key_bulk(self) -> None:
+ """Like test_fallback_key, but claims multiple keys in one handler call."""
+ alice = f"@alice:{self.hs.hostname}"
+ brian = f"@brian:{self.hs.hostname}"
+ chris = f"@chris:{self.hs.hostname}"
+
+ # Have three users upload fallback keys for two devices.
+ fallback_keys = {
+ alice: {
+ "alice_dev_1": {"alg1:k1": "fallback_key1"},
+ "alice_dev_2": {"alg2:k2": "fallback_key2"},
+ },
+ brian: {
+ "brian_dev_1": {"alg1:k3": "fallback_key3"},
+ "brian_dev_2": {"alg2:k4": "fallback_key4"},
+ },
+ chris: {
+ "chris_dev_1": {"alg1:k5": "fallback_key5"},
+ "chris_dev_2": {"alg2:k6": "fallback_key6"},
+ },
+ }
+
+ for user_id, devices in fallback_keys.items():
+ for device_id, key_dict in devices.items():
+ self.get_success(
+ self.handler.upload_keys_for_user(
+ user_id,
+ device_id,
+ {"fallback_keys": key_dict},
+ )
+ )
+
+ # Each device should have an unused fallback key.
+ for user_id, devices in fallback_keys.items():
+ for device_id in devices:
+ fallback_res = self.get_success(
+ self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
+ )
+ expected_algorithm_name = f"alg{device_id[-1]}"
+ self.assertEqual(fallback_res, [expected_algorithm_name])
+
+ # Claim the fallback key for one device per user.
+ claim_res = self.get_success(
+ self.handler.claim_one_time_keys(
+ {
+ alice: {"alice_dev_1": {"alg1": 1}},
+ brian: {"brian_dev_2": {"alg2": 1}},
+ chris: {"chris_dev_2": {"alg2": 1}},
+ },
+ self.requester,
+ timeout=None,
+ always_include_fallback_keys=False,
+ )
+ )
+ expected_claims = {
+ alice: {"alice_dev_1": {"alg1:k1": "fallback_key1"}},
+ brian: {"brian_dev_2": {"alg2:k4": "fallback_key4"}},
+ chris: {"chris_dev_2": {"alg2:k6": "fallback_key6"}},
+ }
+ self.assertEqual(
+ claim_res,
+ {"failures": {}, "one_time_keys": expected_claims},
+ )
+
+ for user_id, devices in fallback_keys.items():
+ for device_id in devices:
+ fallback_res = self.get_success(
+ self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
+ )
+ # Claimed fallback keys should no longer show up as unused.
+ # Unclaimed fallback keys should still be unused.
+ if device_id in expected_claims[user_id]:
+ self.assertEqual(fallback_res, [])
+ else:
+ expected_algorithm_name = f"alg{device_id[-1]}"
+ self.assertEqual(fallback_res, [expected_algorithm_name])
+
def test_fallback_key_always_returned(self) -> None:
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
|