summary refs log tree commit diff
path: root/tests/handlers/test_e2e_keys.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers/test_e2e_keys.py')
-rw-r--r--tests/handlers/test_e2e_keys.py158
1 files changed, 158 insertions, 0 deletions
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 24e405f429..90b4da9ad5 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -174,6 +174,164 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             },
         )
 
+    def test_claim_one_time_key_bulk(self) -> None:
+        """Like test_claim_one_time_key but claims multiple keys in one handler call."""
+        # Apologies to the reader. This test is a little too verbose. It is particularly
+        # tricky to make assertions neatly with all these nested dictionaries in play.
+
+        # Three users with two devices each. Each device uses two algorithms.
+        # Each algorithm is invoked with two keys.
+        alice = f"@alice:{self.hs.hostname}"
+        brian = f"@brian:{self.hs.hostname}"
+        chris = f"@chris:{self.hs.hostname}"
+        one_time_keys = {
+            alice: {
+                "alice_dev_1": {
+                    "alg1:k1": {"dummy_id": 1},
+                    "alg1:k2": {"dummy_id": 2},
+                    "alg2:k3": {"dummy_id": 3},
+                    "alg2:k4": {"dummy_id": 4},
+                },
+                "alice_dev_2": {
+                    "alg1:k5": {"dummy_id": 5},
+                    "alg1:k6": {"dummy_id": 6},
+                    "alg2:k7": {"dummy_id": 7},
+                    "alg2:k8": {"dummy_id": 8},
+                },
+            },
+            brian: {
+                "brian_dev_1": {
+                    "alg1:k9": {"dummy_id": 9},
+                    "alg1:k10": {"dummy_id": 10},
+                    "alg2:k11": {"dummy_id": 11},
+                    "alg2:k12": {"dummy_id": 12},
+                },
+                "brian_dev_2": {
+                    "alg1:k13": {"dummy_id": 13},
+                    "alg1:k14": {"dummy_id": 14},
+                    "alg2:k15": {"dummy_id": 15},
+                    "alg2:k16": {"dummy_id": 16},
+                },
+            },
+            chris: {
+                "chris_dev_1": {
+                    "alg1:k17": {"dummy_id": 17},
+                    "alg1:k18": {"dummy_id": 18},
+                    "alg2:k19": {"dummy_id": 19},
+                    "alg2:k20": {"dummy_id": 20},
+                },
+                "chris_dev_2": {
+                    "alg1:k21": {"dummy_id": 21},
+                    "alg1:k22": {"dummy_id": 22},
+                    "alg2:k23": {"dummy_id": 23},
+                    "alg2:k24": {"dummy_id": 24},
+                },
+            },
+        }
+        for user_id, devices in one_time_keys.items():
+            for device_id, keys_dict in devices.items():
+                counts = self.get_success(
+                    self.handler.upload_keys_for_user(
+                        user_id,
+                        device_id,
+                        {"one_time_keys": keys_dict},
+                    )
+                )
+                # The upload should report 2 keys per algorithm.
+                expected_counts = {
+                    "one_time_key_counts": {
+                        # See count_e2e_one_time_keys for why this is hardcoded.
+                        "signed_curve25519": 0,
+                        "alg1": 2,
+                        "alg2": 2,
+                    },
+                }
+                self.assertEqual(counts, expected_counts)
+
+        # Claim a variety of keys.
+        # Raw format, easier to make test assertions about.
+        claims_to_make = {
+            (alice, "alice_dev_1", "alg1"): 1,
+            (alice, "alice_dev_1", "alg2"): 2,
+            (alice, "alice_dev_2", "alg2"): 1,
+            (brian, "brian_dev_1", "alg1"): 2,
+            (brian, "brian_dev_2", "alg2"): 9001,
+            (chris, "chris_dev_2", "alg2"): 1,
+        }
+        # Convert to the format the handler wants.
+        query: Dict[str, Dict[str, Dict[str, int]]] = {}
+        for (user_id, device_id, algorithm), count in claims_to_make.items():
+            query.setdefault(user_id, {}).setdefault(device_id, {})[algorithm] = count
+        claim_res = self.get_success(
+            self.handler.claim_one_time_keys(
+                query,
+                self.requester,
+                timeout=None,
+                always_include_fallback_keys=False,
+            )
+        )
+
+        # No failures, please!
+        self.assertEqual(claim_res["failures"], {})
+
+        # Check that we get exactly the (user, device, algorithm)s we asked for.
+        got_otks = claim_res["one_time_keys"]
+        claimed_user_device_algorithms = {
+            (user_id, device_id, alg_key_id.split(":")[0])
+            for user_id, devices in got_otks.items()
+            for device_id, key_dict in devices.items()
+            for alg_key_id in key_dict
+        }
+        self.assertEqual(claimed_user_device_algorithms, set(claims_to_make))
+
+        # Now check the keys we got are what we expected.
+        def assertExactlyOneOtk(
+            user_id: str, device_id: str, *alg_key_pairs: str
+        ) -> None:
+            key_dict = got_otks[user_id][device_id]
+            found = 0
+            for alg_key in alg_key_pairs:
+                if alg_key in key_dict:
+                    expected_key_json = one_time_keys[user_id][device_id][alg_key]
+                    self.assertEqual(key_dict[alg_key], expected_key_json)
+                    found += 1
+            self.assertEqual(found, 1)
+
+        def assertAllOtks(user_id: str, device_id: str, *alg_key_pairs: str) -> None:
+            key_dict = got_otks[user_id][device_id]
+            for alg_key in alg_key_pairs:
+                expected_key_json = one_time_keys[user_id][device_id][alg_key]
+                self.assertEqual(key_dict[alg_key], expected_key_json)
+
+        # Expect a single arbitrary key to be returned.
+        assertExactlyOneOtk(alice, "alice_dev_1", "alg1:k1", "alg1:k2")
+        assertExactlyOneOtk(alice, "alice_dev_2", "alg2:k7", "alg2:k8")
+        assertExactlyOneOtk(chris, "chris_dev_2", "alg2:k23", "alg2:k24")
+
+        assertAllOtks(alice, "alice_dev_1", "alg2:k3", "alg2:k4")
+        assertAllOtks(brian, "brian_dev_1", "alg1:k9", "alg1:k10")
+        assertAllOtks(brian, "brian_dev_2", "alg2:k15", "alg2:k16")
+
+        # Now check the unused key counts.
+        for user_id, devices in one_time_keys.items():
+            for device_id in devices:
+                counts_by_alg = self.get_success(
+                    self.store.count_e2e_one_time_keys(user_id, device_id)
+                )
+                # Somewhat fiddley to compute the expected count dict.
+                expected_counts_by_alg = {
+                    "signed_curve25519": 0,
+                }
+                for alg in ["alg1", "alg2"]:
+                    claim_count = claims_to_make.get((user_id, device_id, alg), 0)
+                    remaining_count = max(0, 2 - claim_count)
+                    if remaining_count > 0:
+                        expected_counts_by_alg[alg] = remaining_count
+
+                self.assertEqual(
+                    counts_by_alg, expected_counts_by_alg, f"{user_id}:{device_id}"
+                )
+
     def test_fallback_key(self) -> None:
         local_user = "@boris:" + self.hs.hostname
         device_id = "xyz"