summary refs log tree commit diff
path: root/tests/handlers/test_e2e_keys.py
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2023-10-30 14:34:37 +0000
committerGitHub <noreply@github.com>2023-10-30 14:34:37 +0000
commitfdce83ee60b3b5ffd0c41d112873a4de52b1e640 (patch)
treef25decfeeb127373476ec32798cea295b12506b2 /tests/handlers/test_e2e_keys.py
parentBump setuptools-rust from 1.7.0 to 1.8.0 (#16574) (diff)
downloadsynapse-fdce83ee60b3b5ffd0c41d112873a4de52b1e640.tar.xz
Claim fallback keys in bulk (#16570)
Diffstat (limited to 'tests/handlers/test_e2e_keys.py')
-rw-r--r--tests/handlers/test_e2e_keys.py77
1 files changed, 77 insertions, 0 deletions
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index c5556f2844..24e405f429 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -322,6 +322,83 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
         )
 
+    def test_fallback_key_bulk(self) -> None:
+        """Like test_fallback_key, but claims multiple keys in one handler call."""
+        alice = f"@alice:{self.hs.hostname}"
+        brian = f"@brian:{self.hs.hostname}"
+        chris = f"@chris:{self.hs.hostname}"
+
+        # Have three users upload fallback keys for two devices.
+        fallback_keys = {
+            alice: {
+                "alice_dev_1": {"alg1:k1": "fallback_key1"},
+                "alice_dev_2": {"alg2:k2": "fallback_key2"},
+            },
+            brian: {
+                "brian_dev_1": {"alg1:k3": "fallback_key3"},
+                "brian_dev_2": {"alg2:k4": "fallback_key4"},
+            },
+            chris: {
+                "chris_dev_1": {"alg1:k5": "fallback_key5"},
+                "chris_dev_2": {"alg2:k6": "fallback_key6"},
+            },
+        }
+
+        for user_id, devices in fallback_keys.items():
+            for device_id, key_dict in devices.items():
+                self.get_success(
+                    self.handler.upload_keys_for_user(
+                        user_id,
+                        device_id,
+                        {"fallback_keys": key_dict},
+                    )
+                )
+
+        # Each device should have an unused fallback key.
+        for user_id, devices in fallback_keys.items():
+            for device_id in devices:
+                fallback_res = self.get_success(
+                    self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
+                )
+                expected_algorithm_name = f"alg{device_id[-1]}"
+                self.assertEqual(fallback_res, [expected_algorithm_name])
+
+        # Claim the fallback key for one device per user.
+        claim_res = self.get_success(
+            self.handler.claim_one_time_keys(
+                {
+                    alice: {"alice_dev_1": {"alg1": 1}},
+                    brian: {"brian_dev_2": {"alg2": 1}},
+                    chris: {"chris_dev_2": {"alg2": 1}},
+                },
+                self.requester,
+                timeout=None,
+                always_include_fallback_keys=False,
+            )
+        )
+        expected_claims = {
+            alice: {"alice_dev_1": {"alg1:k1": "fallback_key1"}},
+            brian: {"brian_dev_2": {"alg2:k4": "fallback_key4"}},
+            chris: {"chris_dev_2": {"alg2:k6": "fallback_key6"}},
+        }
+        self.assertEqual(
+            claim_res,
+            {"failures": {}, "one_time_keys": expected_claims},
+        )
+
+        for user_id, devices in fallback_keys.items():
+            for device_id in devices:
+                fallback_res = self.get_success(
+                    self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
+                )
+                # Claimed fallback keys should no longer show up as unused.
+                # Unclaimed fallback keys should still be unused.
+                if device_id in expected_claims[user_id]:
+                    self.assertEqual(fallback_res, [])
+                else:
+                    expected_algorithm_name = f"alg{device_id[-1]}"
+                    self.assertEqual(fallback_res, [expected_algorithm_name])
+
     def test_fallback_key_always_returned(self) -> None:
         local_user = "@boris:" + self.hs.hostname
         device_id = "xyz"