summary refs log tree commit diff
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2023-10-30 21:25:21 +0000
committerGitHub <noreply@github.com>2023-10-30 21:25:21 +0000
commitde981ae56720b06c33203e9edd3df08376afb907 (patch)
treef4816898bf1809e855ee2507dec6ec9542e88888
parentClients link fixed in README (#16569) (diff)
downloadsynapse-de981ae56720b06c33203e9edd3df08376afb907.tar.xz
Claim local one-time-keys in bulk (#16565)
Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
-rw-r--r--changelog.d/16565.feature1
-rw-r--r--synapse/handlers/e2e_keys.py10
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py253
-rw-r--r--tests/handlers/test_e2e_keys.py158
4 files changed, 308 insertions, 114 deletions
diff --git a/changelog.d/16565.feature b/changelog.d/16565.feature
new file mode 100644
index 0000000000..c807945fa8
--- /dev/null
+++ b/changelog.d/16565.feature
@@ -0,0 +1 @@
+Improve the performance of claiming encryption keys.
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 91c5fe007d..d340d4aebe 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -753,6 +753,16 @@ class E2eKeysHandler:
     async def upload_keys_for_user(
         self, user_id: str, device_id: str, keys: JsonDict
     ) -> JsonDict:
+        """
+        Args:
+            user_id: user whose keys are being uploaded.
+            device_id: device whose keys are being uploaded.
+            keys: the body of a /keys/upload request.
+
+        Returns a dictionary with one field:
+            "one_time_keys": A mapping from algorithm to number of keys for that
+                algorithm, including those previously persisted.
+        """
         # This can only be called from the main process.
         assert isinstance(self.device_handler, DeviceHandler)
 
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 08385d312f..4f96ac25c7 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -1111,7 +1111,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         ...
 
     async def claim_e2e_one_time_keys(
-        self, query_list: Iterable[Tuple[str, str, str, int]]
+        self, query_list: Collection[Tuple[str, str, str, int]]
     ) -> Tuple[
         Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
     ]:
@@ -1121,131 +1121,63 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             query_list: An iterable of tuples of (user ID, device ID, algorithm).
 
         Returns:
-            A tuple pf:
+            A tuple (results, missing) of:
                 A map of user ID -> a map device ID -> a map of key ID -> JSON.
 
-                A copy of the input which has not been fulfilled.
+                A copy of the input which has not been fulfilled. The returned counts
+                may be less than the input counts. In this case, the returned counts
+                are the number of claims that were not fulfilled.
         """
-
-        @trace
-        def _claim_e2e_one_time_key_simple(
-            txn: LoggingTransaction,
-            user_id: str,
-            device_id: str,
-            algorithm: str,
-            count: int,
-        ) -> List[Tuple[str, str]]:
-            """Claim OTK for device for DBs that don't support RETURNING.
-
-            Returns:
-                A tuple of key name (algorithm + key ID) and key JSON, if an
-                OTK was found.
-            """
-
-            sql = """
-                SELECT key_id, key_json FROM e2e_one_time_keys_json
-                WHERE user_id = ? AND device_id = ? AND algorithm = ?
-                LIMIT ?
-            """
-
-            txn.execute(sql, (user_id, device_id, algorithm, count))
-            otk_rows = list(txn)
-            if not otk_rows:
-                return []
-
-            self.db_pool.simple_delete_many_txn(
-                txn,
-                table="e2e_one_time_keys_json",
-                column="key_id",
-                values=[otk_row[0] for otk_row in otk_rows],
-                keyvalues={
-                    "user_id": user_id,
-                    "device_id": device_id,
-                    "algorithm": algorithm,
-                },
-            )
-            self._invalidate_cache_and_stream(
-                txn, self.count_e2e_one_time_keys, (user_id, device_id)
-            )
-
-            return [
-                (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
-            ]
-
-        @trace
-        def _claim_e2e_one_time_key_returning(
-            txn: LoggingTransaction,
-            user_id: str,
-            device_id: str,
-            algorithm: str,
-            count: int,
-        ) -> List[Tuple[str, str]]:
-            """Claim OTK for device for DBs that support RETURNING.
-
-            Returns:
-                A tuple of key name (algorithm + key ID) and key JSON, if an
-                OTK was found.
-            """
-
-            # We can use RETURNING to do the fetch and DELETE in once step.
-            sql = """
-                DELETE FROM e2e_one_time_keys_json
-                WHERE user_id = ? AND device_id = ? AND algorithm = ?
-                    AND key_id IN (
-                        SELECT key_id FROM e2e_one_time_keys_json
-                        WHERE user_id = ? AND device_id = ? AND algorithm = ?
-                        LIMIT ?
-                    )
-                RETURNING key_id, key_json
-            """
-
-            txn.execute(
-                sql,
-                (user_id, device_id, algorithm, user_id, device_id, algorithm, count),
-            )
-            otk_rows = list(txn)
-            if not otk_rows:
-                return []
-
-            self._invalidate_cache_and_stream(
-                txn, self.count_e2e_one_time_keys, (user_id, device_id)
-            )
-
-            return [
-                (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
-            ]
-
         results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
         missing: List[Tuple[str, str, str, int]] = []
-        for user_id, device_id, algorithm, count in query_list:
-            if self.database_engine.supports_returning:
-                # If we support RETURNING clause we can use a single query that
-                # allows us to use autocommit mode.
-                _claim_e2e_one_time_key = _claim_e2e_one_time_key_returning
-                db_autocommit = True
-            else:
-                _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
-                db_autocommit = False
+        if isinstance(self.database_engine, PostgresEngine):
+            # If we can use execute_values we can use a single batch query
+            # in autocommit mode.
+            unfulfilled_claim_counts: Dict[Tuple[str, str, str], int] = {}
+            for user_id, device_id, algorithm, count in query_list:
+                unfulfilled_claim_counts[user_id, device_id, algorithm] = count
 
-            claim_rows = await self.db_pool.runInteraction(
+            bulk_claims = await self.db_pool.runInteraction(
                 "claim_e2e_one_time_keys",
-                _claim_e2e_one_time_key,
-                user_id,
-                device_id,
-                algorithm,
-                count,
-                db_autocommit=db_autocommit,
+                self._claim_e2e_one_time_keys_bulk,
+                query_list,
+                db_autocommit=True,
             )
-            if claim_rows:
+
+            for user_id, device_id, algorithm, key_id, key_json in bulk_claims:
                 device_results = results.setdefault(user_id, {}).setdefault(
                     device_id, {}
                 )
-                for claim_row in claim_rows:
-                    device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
+                device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
+                unfulfilled_claim_counts[(user_id, device_id, algorithm)] -= 1
+
             # Did we get enough OTKs?
-            count -= len(claim_rows)
-            if count:
-                missing.append((user_id, device_id, algorithm, count))
+            missing = [
+                (user, device, alg, count)
+                for (user, device, alg), count in unfulfilled_claim_counts.items()
+                if count > 0
+            ]
+        else:
+            for user_id, device_id, algorithm, count in query_list:
+                claim_rows = await self.db_pool.runInteraction(
+                    "claim_e2e_one_time_keys",
+                    self._claim_e2e_one_time_key_simple,
+                    user_id,
+                    device_id,
+                    algorithm,
+                    count,
+                    db_autocommit=False,
+                )
+                if claim_rows:
+                    device_results = results.setdefault(user_id, {}).setdefault(
+                        device_id, {}
+                    )
+                    for claim_row in claim_rows:
+                        device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
+                # Did we get enough OTKs?
+                count -= len(claim_rows)
+                if count:
+                    missing.append((user_id, device_id, algorithm, count))
 
         return results, missing
 
@@ -1362,6 +1294,99 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 
         return results
 
+    @trace
+    def _claim_e2e_one_time_key_simple(
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        device_id: str,
+        algorithm: str,
+        count: int,
+    ) -> List[Tuple[str, str]]:
+        """Claim OTK for device for DBs that don't support RETURNING.
+
+        Returns:
+            A tuple of key name (algorithm + key ID) and key JSON, if an
+            OTK was found.
+        """
+
+        sql = """
+            SELECT key_id, key_json FROM e2e_one_time_keys_json
+            WHERE user_id = ? AND device_id = ? AND algorithm = ?
+            LIMIT ?
+        """
+
+        txn.execute(sql, (user_id, device_id, algorithm, count))
+        otk_rows = list(txn)
+        if not otk_rows:
+            return []
+
+        self.db_pool.simple_delete_many_txn(
+            txn,
+            table="e2e_one_time_keys_json",
+            column="key_id",
+            values=[otk_row[0] for otk_row in otk_rows],
+            keyvalues={
+                "user_id": user_id,
+                "device_id": device_id,
+                "algorithm": algorithm,
+            },
+        )
+        self._invalidate_cache_and_stream(
+            txn, self.count_e2e_one_time_keys, (user_id, device_id)
+        )
+
+        return [(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows]
+
+    @trace
+    def _claim_e2e_one_time_keys_bulk(
+        self,
+        txn: LoggingTransaction,
+        query_list: Iterable[Tuple[str, str, str, int]],
+    ) -> List[Tuple[str, str, str, str, str]]:
+        """Bulk claim OTKs, for DBs that support DELETE FROM... RETURNING.
+
+        Args:
+            query_list: Collection of tuples (user_id, device_id, algorithm, count)
+                as passed to claim_e2e_one_time_keys.
+
+        Returns:
+            A list of tuples (user_id, device_id, algorithm, key_id, key_json)
+            for each OTK claimed.
+        """
+        sql = """
+            WITH claims(user_id, device_id, algorithm, claim_count) AS (
+                VALUES ?
+            ), ranked_keys AS (
+                SELECT
+                    user_id, device_id, algorithm, key_id, claim_count,
+                    ROW_NUMBER() OVER (PARTITION BY (user_id, device_id, algorithm)) AS r
+                FROM e2e_one_time_keys_json
+                    JOIN claims USING (user_id, device_id, algorithm)
+            )
+            DELETE FROM e2e_one_time_keys_json k
+            WHERE (user_id, device_id, algorithm, key_id) IN (
+                SELECT user_id, device_id, algorithm, key_id
+                FROM ranked_keys
+                WHERE r <= claim_count
+            )
+            RETURNING user_id, device_id, algorithm, key_id, key_json;
+        """
+        otk_rows = cast(
+            List[Tuple[str, str, str, str, str]], txn.execute_values(sql, query_list)
+        )
+
+        seen_user_device: Set[Tuple[str, str]] = set()
+        for user_id, device_id, _, _, _ in otk_rows:
+            if (user_id, device_id) in seen_user_device:
+                continue
+            seen_user_device.add((user_id, device_id))
+            self._invalidate_cache_and_stream(
+                txn, self.count_e2e_one_time_keys, (user_id, device_id)
+            )
+
+        return otk_rows
+
 
 class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
     def __init__(
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 24e405f429..90b4da9ad5 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -174,6 +174,164 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             },
         )
 
+    def test_claim_one_time_key_bulk(self) -> None:
+        """Like test_claim_one_time_key but claims multiple keys in one handler call."""
+        # Apologies to the reader. This test is a little too verbose. It is particularly
+        # tricky to make assertions neatly with all these nested dictionaries in play.
+
+        # Three users with two devices each. Each device uses two algorithms.
+        # Each algorithm is invoked with two keys.
+        alice = f"@alice:{self.hs.hostname}"
+        brian = f"@brian:{self.hs.hostname}"
+        chris = f"@chris:{self.hs.hostname}"
+        one_time_keys = {
+            alice: {
+                "alice_dev_1": {
+                    "alg1:k1": {"dummy_id": 1},
+                    "alg1:k2": {"dummy_id": 2},
+                    "alg2:k3": {"dummy_id": 3},
+                    "alg2:k4": {"dummy_id": 4},
+                },
+                "alice_dev_2": {
+                    "alg1:k5": {"dummy_id": 5},
+                    "alg1:k6": {"dummy_id": 6},
+                    "alg2:k7": {"dummy_id": 7},
+                    "alg2:k8": {"dummy_id": 8},
+                },
+            },
+            brian: {
+                "brian_dev_1": {
+                    "alg1:k9": {"dummy_id": 9},
+                    "alg1:k10": {"dummy_id": 10},
+                    "alg2:k11": {"dummy_id": 11},
+                    "alg2:k12": {"dummy_id": 12},
+                },
+                "brian_dev_2": {
+                    "alg1:k13": {"dummy_id": 13},
+                    "alg1:k14": {"dummy_id": 14},
+                    "alg2:k15": {"dummy_id": 15},
+                    "alg2:k16": {"dummy_id": 16},
+                },
+            },
+            chris: {
+                "chris_dev_1": {
+                    "alg1:k17": {"dummy_id": 17},
+                    "alg1:k18": {"dummy_id": 18},
+                    "alg2:k19": {"dummy_id": 19},
+                    "alg2:k20": {"dummy_id": 20},
+                },
+                "chris_dev_2": {
+                    "alg1:k21": {"dummy_id": 21},
+                    "alg1:k22": {"dummy_id": 22},
+                    "alg2:k23": {"dummy_id": 23},
+                    "alg2:k24": {"dummy_id": 24},
+                },
+            },
+        }
+        for user_id, devices in one_time_keys.items():
+            for device_id, keys_dict in devices.items():
+                counts = self.get_success(
+                    self.handler.upload_keys_for_user(
+                        user_id,
+                        device_id,
+                        {"one_time_keys": keys_dict},
+                    )
+                )
+                # The upload should report 2 keys per algorithm.
+                expected_counts = {
+                    "one_time_key_counts": {
+                        # See count_e2e_one_time_keys for why this is hardcoded.
+                        "signed_curve25519": 0,
+                        "alg1": 2,
+                        "alg2": 2,
+                    },
+                }
+                self.assertEqual(counts, expected_counts)
+
+        # Claim a variety of keys.
+        # Raw format, easier to make test assertions about.
+        claims_to_make = {
+            (alice, "alice_dev_1", "alg1"): 1,
+            (alice, "alice_dev_1", "alg2"): 2,
+            (alice, "alice_dev_2", "alg2"): 1,
+            (brian, "brian_dev_1", "alg1"): 2,
+            (brian, "brian_dev_2", "alg2"): 9001,
+            (chris, "chris_dev_2", "alg2"): 1,
+        }
+        # Convert to the format the handler wants.
+        query: Dict[str, Dict[str, Dict[str, int]]] = {}
+        for (user_id, device_id, algorithm), count in claims_to_make.items():
+            query.setdefault(user_id, {}).setdefault(device_id, {})[algorithm] = count
+        claim_res = self.get_success(
+            self.handler.claim_one_time_keys(
+                query,
+                self.requester,
+                timeout=None,
+                always_include_fallback_keys=False,
+            )
+        )
+
+        # No failures, please!
+        self.assertEqual(claim_res["failures"], {})
+
+        # Check that we get exactly the (user, device, algorithm)s we asked for.
+        got_otks = claim_res["one_time_keys"]
+        claimed_user_device_algorithms = {
+            (user_id, device_id, alg_key_id.split(":")[0])
+            for user_id, devices in got_otks.items()
+            for device_id, key_dict in devices.items()
+            for alg_key_id in key_dict
+        }
+        self.assertEqual(claimed_user_device_algorithms, set(claims_to_make))
+
+        # Now check the keys we got are what we expected.
+        def assertExactlyOneOtk(
+            user_id: str, device_id: str, *alg_key_pairs: str
+        ) -> None:
+            key_dict = got_otks[user_id][device_id]
+            found = 0
+            for alg_key in alg_key_pairs:
+                if alg_key in key_dict:
+                    expected_key_json = one_time_keys[user_id][device_id][alg_key]
+                    self.assertEqual(key_dict[alg_key], expected_key_json)
+                    found += 1
+            self.assertEqual(found, 1)
+
+        def assertAllOtks(user_id: str, device_id: str, *alg_key_pairs: str) -> None:
+            key_dict = got_otks[user_id][device_id]
+            for alg_key in alg_key_pairs:
+                expected_key_json = one_time_keys[user_id][device_id][alg_key]
+                self.assertEqual(key_dict[alg_key], expected_key_json)
+
+        # Expect a single arbitrary key to be returned.
+        assertExactlyOneOtk(alice, "alice_dev_1", "alg1:k1", "alg1:k2")
+        assertExactlyOneOtk(alice, "alice_dev_2", "alg2:k7", "alg2:k8")
+        assertExactlyOneOtk(chris, "chris_dev_2", "alg2:k23", "alg2:k24")
+
+        assertAllOtks(alice, "alice_dev_1", "alg2:k3", "alg2:k4")
+        assertAllOtks(brian, "brian_dev_1", "alg1:k9", "alg1:k10")
+        assertAllOtks(brian, "brian_dev_2", "alg2:k15", "alg2:k16")
+
+        # Now check the unused key counts.
+        for user_id, devices in one_time_keys.items():
+            for device_id in devices:
+                counts_by_alg = self.get_success(
+                    self.store.count_e2e_one_time_keys(user_id, device_id)
+                )
+                # Somewhat fiddley to compute the expected count dict.
+                expected_counts_by_alg = {
+                    "signed_curve25519": 0,
+                }
+                for alg in ["alg1", "alg2"]:
+                    claim_count = claims_to_make.get((user_id, device_id, alg), 0)
+                    remaining_count = max(0, 2 - claim_count)
+                    if remaining_count > 0:
+                        expected_counts_by_alg[alg] = remaining_count
+
+                self.assertEqual(
+                    counts_by_alg, expected_counts_by_alg, f"{user_id}:{device_id}"
+                )
+
     def test_fallback_key(self) -> None:
         local_user = "@boris:" + self.hs.hostname
         device_id = "xyz"