summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-04-27 12:57:46 -0400
committerGitHub <noreply@github.com>2023-04-27 12:57:46 -0400
commit57aeeb308b39c4fd455682966eabc9c0fa17c65d (patch)
tree3b59e2a367f7894a2adfca66c6579fe317723a39 /synapse/storage
parentAdd type hints to schema deltas (#15497) (diff)
downloadsynapse-57aeeb308b39c4fd455682966eabc9c0fa17c65d.tar.xz
Add support for claiming multiple OTKs at once. (#15468)
MSC3983 provides a way to request multiple OTKs at once from appservices,
this extends this concept to the Client-Server API.

Note that this will likely be spit out into a separate MSC, but is currently part of
MSC3983.
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py77
1 files changed, 47 insertions, 30 deletions
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 1a4ae55304..4bc391f213 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -1027,8 +1027,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         ...
 
     async def claim_e2e_one_time_keys(
-        self, query_list: Iterable[Tuple[str, str, str]]
-    ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
+        self, query_list: Iterable[Tuple[str, str, str, int]]
+    ) -> Tuple[
+        Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
+    ]:
         """Take a list of one time keys out of the database.
 
         Args:
@@ -1043,8 +1045,12 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 
         @trace
         def _claim_e2e_one_time_key_simple(
-            txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
-        ) -> Optional[Tuple[str, str]]:
+            txn: LoggingTransaction,
+            user_id: str,
+            device_id: str,
+            algorithm: str,
+            count: int,
+        ) -> List[Tuple[str, str]]:
             """Claim OTK for device for DBs that don't support RETURNING.
 
             Returns:
@@ -1055,36 +1061,41 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             sql = """
                 SELECT key_id, key_json FROM e2e_one_time_keys_json
                 WHERE user_id = ? AND device_id = ? AND algorithm = ?
-                LIMIT 1
+                LIMIT ?
             """
 
-            txn.execute(sql, (user_id, device_id, algorithm))
-            otk_row = txn.fetchone()
-            if otk_row is None:
-                return None
+            txn.execute(sql, (user_id, device_id, algorithm, count))
+            otk_rows = list(txn)
+            if not otk_rows:
+                return []
 
-            key_id, key_json = otk_row
-
-            self.db_pool.simple_delete_one_txn(
+            self.db_pool.simple_delete_many_txn(
                 txn,
                 table="e2e_one_time_keys_json",
+                column="key_id",
+                values=[otk_row[0] for otk_row in otk_rows],
                 keyvalues={
                     "user_id": user_id,
                     "device_id": device_id,
                     "algorithm": algorithm,
-                    "key_id": key_id,
                 },
             )
             self._invalidate_cache_and_stream(
                 txn, self.count_e2e_one_time_keys, (user_id, device_id)
             )
 
-            return f"{algorithm}:{key_id}", key_json
+            return [
+                (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
+            ]
 
         @trace
         def _claim_e2e_one_time_key_returning(
-            txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
-        ) -> Optional[Tuple[str, str]]:
+            txn: LoggingTransaction,
+            user_id: str,
+            device_id: str,
+            algorithm: str,
+            count: int,
+        ) -> List[Tuple[str, str]]:
             """Claim OTK for device for DBs that support RETURNING.
 
             Returns:
@@ -1099,28 +1110,30 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
                     AND key_id IN (
                         SELECT key_id FROM e2e_one_time_keys_json
                         WHERE user_id = ? AND device_id = ? AND algorithm = ?
-                        LIMIT 1
+                        LIMIT ?
                     )
                 RETURNING key_id, key_json
             """
 
             txn.execute(
-                sql, (user_id, device_id, algorithm, user_id, device_id, algorithm)
+                sql,
+                (user_id, device_id, algorithm, user_id, device_id, algorithm, count),
             )
-            otk_row = txn.fetchone()
-            if otk_row is None:
-                return None
+            otk_rows = list(txn)
+            if not otk_rows:
+                return []
 
             self._invalidate_cache_and_stream(
                 txn, self.count_e2e_one_time_keys, (user_id, device_id)
             )
 
-            key_id, key_json = otk_row
-            return f"{algorithm}:{key_id}", key_json
+            return [
+                (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
+            ]
 
         results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
-        missing: List[Tuple[str, str, str]] = []
-        for user_id, device_id, algorithm in query_list:
+        missing: List[Tuple[str, str, str, int]] = []
+        for user_id, device_id, algorithm, count in query_list:
             if self.database_engine.supports_returning:
                 # If we support RETURNING clause we can use a single query that
                 # allows us to use autocommit mode.
@@ -1130,21 +1143,25 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
                 _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
                 db_autocommit = False
 
-            claim_row = await self.db_pool.runInteraction(
+            claim_rows = await self.db_pool.runInteraction(
                 "claim_e2e_one_time_keys",
                 _claim_e2e_one_time_key,
                 user_id,
                 device_id,
                 algorithm,
+                count,
                 db_autocommit=db_autocommit,
             )
-            if claim_row:
+            if claim_rows:
                 device_results = results.setdefault(user_id, {}).setdefault(
                     device_id, {}
                 )
-                device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
-            else:
-                missing.append((user_id, device_id, algorithm))
+                for claim_row in claim_rows:
+                    device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
+            # Did we get enough OTKs?
+            count -= len(claim_rows)
+            if count:
+                missing.append((user_id, device_id, algorithm, count))
 
         return results, missing