diff options
author | Patrick Cloke <clokep@users.noreply.github.com> | 2023-04-27 12:57:46 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-04-27 12:57:46 -0400 |
commit | 57aeeb308b39c4fd455682966eabc9c0fa17c65d (patch) | |
tree | 3b59e2a367f7894a2adfca66c6579fe317723a39 /synapse/storage/databases | |
parent | Add type hints to schema deltas (#15497) (diff) | |
download | synapse-57aeeb308b39c4fd455682966eabc9c0fa17c65d.tar.xz |
Add support for claiming multiple OTKs at once. (#15468)
MSC3983 provides a way to request multiple OTKs at once from appservices, this extends this concept to the Client-Server API. Note that this will likely be spit out into a separate MSC, but is currently part of MSC3983.
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r-- | synapse/storage/databases/main/end_to_end_keys.py | 77 |
1 files changed, 47 insertions, 30 deletions
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 1a4ae55304..4bc391f213 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -1027,8 +1027,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker ... async def claim_e2e_one_time_keys( - self, query_list: Iterable[Tuple[str, str, str]] - ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]: + self, query_list: Iterable[Tuple[str, str, str, int]] + ) -> Tuple[ + Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]] + ]: """Take a list of one time keys out of the database. Args: @@ -1043,8 +1045,12 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker @trace def _claim_e2e_one_time_key_simple( - txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str - ) -> Optional[Tuple[str, str]]: + txn: LoggingTransaction, + user_id: str, + device_id: str, + algorithm: str, + count: int, + ) -> List[Tuple[str, str]]: """Claim OTK for device for DBs that don't support RETURNING. Returns: @@ -1055,36 +1061,41 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker sql = """ SELECT key_id, key_json FROM e2e_one_time_keys_json WHERE user_id = ? AND device_id = ? AND algorithm = ? - LIMIT 1 + LIMIT ? """ - txn.execute(sql, (user_id, device_id, algorithm)) - otk_row = txn.fetchone() - if otk_row is None: - return None + txn.execute(sql, (user_id, device_id, algorithm, count)) + otk_rows = list(txn) + if not otk_rows: + return [] - key_id, key_json = otk_row - - self.db_pool.simple_delete_one_txn( + self.db_pool.simple_delete_many_txn( txn, table="e2e_one_time_keys_json", + column="key_id", + values=[otk_row[0] for otk_row in otk_rows], keyvalues={ "user_id": user_id, "device_id": device_id, "algorithm": algorithm, - "key_id": key_id, }, ) self._invalidate_cache_and_stream( txn, self.count_e2e_one_time_keys, (user_id, device_id) ) - return f"{algorithm}:{key_id}", key_json + return [ + (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows + ] @trace def _claim_e2e_one_time_key_returning( - txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str - ) -> Optional[Tuple[str, str]]: + txn: LoggingTransaction, + user_id: str, + device_id: str, + algorithm: str, + count: int, + ) -> List[Tuple[str, str]]: """Claim OTK for device for DBs that support RETURNING. Returns: @@ -1099,28 +1110,30 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker AND key_id IN ( SELECT key_id FROM e2e_one_time_keys_json WHERE user_id = ? AND device_id = ? AND algorithm = ? - LIMIT 1 + LIMIT ? ) RETURNING key_id, key_json """ txn.execute( - sql, (user_id, device_id, algorithm, user_id, device_id, algorithm) + sql, + (user_id, device_id, algorithm, user_id, device_id, algorithm, count), ) - otk_row = txn.fetchone() - if otk_row is None: - return None + otk_rows = list(txn) + if not otk_rows: + return [] self._invalidate_cache_and_stream( txn, self.count_e2e_one_time_keys, (user_id, device_id) ) - key_id, key_json = otk_row - return f"{algorithm}:{key_id}", key_json + return [ + (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows + ] results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} - missing: List[Tuple[str, str, str]] = [] - for user_id, device_id, algorithm in query_list: + missing: List[Tuple[str, str, str, int]] = [] + for user_id, device_id, algorithm, count in query_list: if self.database_engine.supports_returning: # If we support RETURNING clause we can use a single query that # allows us to use autocommit mode. @@ -1130,21 +1143,25 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple db_autocommit = False - claim_row = await self.db_pool.runInteraction( + claim_rows = await self.db_pool.runInteraction( "claim_e2e_one_time_keys", _claim_e2e_one_time_key, user_id, device_id, algorithm, + count, db_autocommit=db_autocommit, ) - if claim_row: + if claim_rows: device_results = results.setdefault(user_id, {}).setdefault( device_id, {} ) - device_results[claim_row[0]] = json_decoder.decode(claim_row[1]) - else: - missing.append((user_id, device_id, algorithm)) + for claim_row in claim_rows: + device_results[claim_row[0]] = json_decoder.decode(claim_row[1]) + # Did we get enough OTKs? + count -= len(claim_rows) + if count: + missing.append((user_id, device_id, algorithm, count)) return results, missing |