summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/e2e_keys.py10
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py253
2 files changed, 149 insertions, 114 deletions
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 91c5fe007d..d340d4aebe 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -753,6 +753,16 @@ class E2eKeysHandler:
     async def upload_keys_for_user(
         self, user_id: str, device_id: str, keys: JsonDict
     ) -> JsonDict:
+        """
+        Args:
+            user_id: user whose keys are being uploaded.
+            device_id: device whose keys are being uploaded.
+            keys: the body of a /keys/upload request.
+
+        Returns a dictionary with one field:
+            "one_time_keys": A mapping from algorithm to number of keys for that
+                algorithm, including those previously persisted.
+        """
         # This can only be called from the main process.
         assert isinstance(self.device_handler, DeviceHandler)
 
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 08385d312f..4f96ac25c7 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -1111,7 +1111,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         ...
 
     async def claim_e2e_one_time_keys(
-        self, query_list: Iterable[Tuple[str, str, str, int]]
+        self, query_list: Collection[Tuple[str, str, str, int]]
     ) -> Tuple[
         Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
     ]:
@@ -1121,131 +1121,63 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             query_list: An iterable of tuples of (user ID, device ID, algorithm).
 
         Returns:
-            A tuple pf:
+            A tuple (results, missing) of:
                 A map of user ID -> a map device ID -> a map of key ID -> JSON.
 
-                A copy of the input which has not been fulfilled.
+                A copy of the input which has not been fulfilled. The returned counts
+                may be less than the input counts. In this case, the returned counts
+                are the number of claims that were not fulfilled.
         """
-
-        @trace
-        def _claim_e2e_one_time_key_simple(
-            txn: LoggingTransaction,
-            user_id: str,
-            device_id: str,
-            algorithm: str,
-            count: int,
-        ) -> List[Tuple[str, str]]:
-            """Claim OTK for device for DBs that don't support RETURNING.
-
-            Returns:
-                A tuple of key name (algorithm + key ID) and key JSON, if an
-                OTK was found.
-            """
-
-            sql = """
-                SELECT key_id, key_json FROM e2e_one_time_keys_json
-                WHERE user_id = ? AND device_id = ? AND algorithm = ?
-                LIMIT ?
-            """
-
-            txn.execute(sql, (user_id, device_id, algorithm, count))
-            otk_rows = list(txn)
-            if not otk_rows:
-                return []
-
-            self.db_pool.simple_delete_many_txn(
-                txn,
-                table="e2e_one_time_keys_json",
-                column="key_id",
-                values=[otk_row[0] for otk_row in otk_rows],
-                keyvalues={
-                    "user_id": user_id,
-                    "device_id": device_id,
-                    "algorithm": algorithm,
-                },
-            )
-            self._invalidate_cache_and_stream(
-                txn, self.count_e2e_one_time_keys, (user_id, device_id)
-            )
-
-            return [
-                (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
-            ]
-
-        @trace
-        def _claim_e2e_one_time_key_returning(
-            txn: LoggingTransaction,
-            user_id: str,
-            device_id: str,
-            algorithm: str,
-            count: int,
-        ) -> List[Tuple[str, str]]:
-            """Claim OTK for device for DBs that support RETURNING.
-
-            Returns:
-                A tuple of key name (algorithm + key ID) and key JSON, if an
-                OTK was found.
-            """
-
-            # We can use RETURNING to do the fetch and DELETE in once step.
-            sql = """
-                DELETE FROM e2e_one_time_keys_json
-                WHERE user_id = ? AND device_id = ? AND algorithm = ?
-                    AND key_id IN (
-                        SELECT key_id FROM e2e_one_time_keys_json
-                        WHERE user_id = ? AND device_id = ? AND algorithm = ?
-                        LIMIT ?
-                    )
-                RETURNING key_id, key_json
-            """
-
-            txn.execute(
-                sql,
-                (user_id, device_id, algorithm, user_id, device_id, algorithm, count),
-            )
-            otk_rows = list(txn)
-            if not otk_rows:
-                return []
-
-            self._invalidate_cache_and_stream(
-                txn, self.count_e2e_one_time_keys, (user_id, device_id)
-            )
-
-            return [
-                (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
-            ]
-
         results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
         missing: List[Tuple[str, str, str, int]] = []
-        for user_id, device_id, algorithm, count in query_list:
-            if self.database_engine.supports_returning:
-                # If we support RETURNING clause we can use a single query that
-                # allows us to use autocommit mode.
-                _claim_e2e_one_time_key = _claim_e2e_one_time_key_returning
-                db_autocommit = True
-            else:
-                _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
-                db_autocommit = False
+        if isinstance(self.database_engine, PostgresEngine):
+            # If we can use execute_values we can use a single batch query
+            # in autocommit mode.
+            unfulfilled_claim_counts: Dict[Tuple[str, str, str], int] = {}
+            for user_id, device_id, algorithm, count in query_list:
+                unfulfilled_claim_counts[user_id, device_id, algorithm] = count
 
-            claim_rows = await self.db_pool.runInteraction(
+            bulk_claims = await self.db_pool.runInteraction(
                 "claim_e2e_one_time_keys",
-                _claim_e2e_one_time_key,
-                user_id,
-                device_id,
-                algorithm,
-                count,
-                db_autocommit=db_autocommit,
+                self._claim_e2e_one_time_keys_bulk,
+                query_list,
+                db_autocommit=True,
             )
-            if claim_rows:
+
+            for user_id, device_id, algorithm, key_id, key_json in bulk_claims:
                 device_results = results.setdefault(user_id, {}).setdefault(
                     device_id, {}
                 )
-                for claim_row in claim_rows:
-                    device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
+                device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
+                unfulfilled_claim_counts[(user_id, device_id, algorithm)] -= 1
+
             # Did we get enough OTKs?
-            count -= len(claim_rows)
-            if count:
-                missing.append((user_id, device_id, algorithm, count))
+            missing = [
+                (user, device, alg, count)
+                for (user, device, alg), count in unfulfilled_claim_counts.items()
+                if count > 0
+            ]
+        else:
+            for user_id, device_id, algorithm, count in query_list:
+                claim_rows = await self.db_pool.runInteraction(
+                    "claim_e2e_one_time_keys",
+                    self._claim_e2e_one_time_key_simple,
+                    user_id,
+                    device_id,
+                    algorithm,
+                    count,
+                    db_autocommit=False,
+                )
+                if claim_rows:
+                    device_results = results.setdefault(user_id, {}).setdefault(
+                        device_id, {}
+                    )
+                    for claim_row in claim_rows:
+                        device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
+                # Did we get enough OTKs?
+                count -= len(claim_rows)
+                if count:
+                    missing.append((user_id, device_id, algorithm, count))
 
         return results, missing
 
@@ -1362,6 +1294,99 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 
         return results
 
+    @trace
+    def _claim_e2e_one_time_key_simple(
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        device_id: str,
+        algorithm: str,
+        count: int,
+    ) -> List[Tuple[str, str]]:
+        """Claim OTK for device for DBs that don't support RETURNING.
+
+        Returns:
+            A tuple of key name (algorithm + key ID) and key JSON, if an
+            OTK was found.
+        """
+
+        sql = """
+            SELECT key_id, key_json FROM e2e_one_time_keys_json
+            WHERE user_id = ? AND device_id = ? AND algorithm = ?
+            LIMIT ?
+        """
+
+        txn.execute(sql, (user_id, device_id, algorithm, count))
+        otk_rows = list(txn)
+        if not otk_rows:
+            return []
+
+        self.db_pool.simple_delete_many_txn(
+            txn,
+            table="e2e_one_time_keys_json",
+            column="key_id",
+            values=[otk_row[0] for otk_row in otk_rows],
+            keyvalues={
+                "user_id": user_id,
+                "device_id": device_id,
+                "algorithm": algorithm,
+            },
+        )
+        self._invalidate_cache_and_stream(
+            txn, self.count_e2e_one_time_keys, (user_id, device_id)
+        )
+
+        return [(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows]
+
+    @trace
+    def _claim_e2e_one_time_keys_bulk(
+        self,
+        txn: LoggingTransaction,
+        query_list: Iterable[Tuple[str, str, str, int]],
+    ) -> List[Tuple[str, str, str, str, str]]:
+        """Bulk claim OTKs, for DBs that support DELETE FROM... RETURNING.
+
+        Args:
+            query_list: Collection of tuples (user_id, device_id, algorithm, count)
+                as passed to claim_e2e_one_time_keys.
+
+        Returns:
+            A list of tuples (user_id, device_id, algorithm, key_id, key_json)
+            for each OTK claimed.
+        """
+        sql = """
+            WITH claims(user_id, device_id, algorithm, claim_count) AS (
+                VALUES ?
+            ), ranked_keys AS (
+                SELECT
+                    user_id, device_id, algorithm, key_id, claim_count,
+                    ROW_NUMBER() OVER (PARTITION BY (user_id, device_id, algorithm)) AS r
+                FROM e2e_one_time_keys_json
+                    JOIN claims USING (user_id, device_id, algorithm)
+            )
+            DELETE FROM e2e_one_time_keys_json k
+            WHERE (user_id, device_id, algorithm, key_id) IN (
+                SELECT user_id, device_id, algorithm, key_id
+                FROM ranked_keys
+                WHERE r <= claim_count
+            )
+            RETURNING user_id, device_id, algorithm, key_id, key_json;
+        """
+        otk_rows = cast(
+            List[Tuple[str, str, str, str, str]], txn.execute_values(sql, query_list)
+        )
+
+        seen_user_device: Set[Tuple[str, str]] = set()
+        for user_id, device_id, _, _, _ in otk_rows:
+            if (user_id, device_id) in seen_user_device:
+                continue
+            seen_user_device.add((user_id, device_id))
+            self._invalidate_cache_and_stream(
+                txn, self.count_e2e_one_time_keys, (user_id, device_id)
+            )
+
+        return otk_rows
+
 
 class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
     def __init__(