diff options
author | David Robertson <davidr@element.io> | 2023-10-30 21:25:21 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-30 21:25:21 +0000 |
commit | de981ae56720b06c33203e9edd3df08376afb907 (patch) | |
tree | f4816898bf1809e855ee2507dec6ec9542e88888 /synapse/storage/databases/main | |
parent | Clients link fixed in README (#16569) (diff) | |
download | synapse-de981ae56720b06c33203e9edd3df08376afb907.tar.xz |
Claim local one-time-keys in bulk (#16565)
Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r-- | synapse/storage/databases/main/end_to_end_keys.py | 253 |
1 files changed, 139 insertions, 114 deletions
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 08385d312f..4f96ac25c7 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -1111,7 +1111,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker ... async def claim_e2e_one_time_keys( - self, query_list: Iterable[Tuple[str, str, str, int]] + self, query_list: Collection[Tuple[str, str, str, int]] ) -> Tuple[ Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]] ]: @@ -1121,131 +1121,63 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker query_list: An iterable of tuples of (user ID, device ID, algorithm). Returns: - A tuple pf: + A tuple (results, missing) of: A map of user ID -> a map device ID -> a map of key ID -> JSON. - A copy of the input which has not been fulfilled. + A copy of the input which has not been fulfilled. The returned counts + may be less than the input counts. In this case, the returned counts + are the number of claims that were not fulfilled. """ - - @trace - def _claim_e2e_one_time_key_simple( - txn: LoggingTransaction, - user_id: str, - device_id: str, - algorithm: str, - count: int, - ) -> List[Tuple[str, str]]: - """Claim OTK for device for DBs that don't support RETURNING. - - Returns: - A tuple of key name (algorithm + key ID) and key JSON, if an - OTK was found. - """ - - sql = """ - SELECT key_id, key_json FROM e2e_one_time_keys_json - WHERE user_id = ? AND device_id = ? AND algorithm = ? - LIMIT ? - """ - - txn.execute(sql, (user_id, device_id, algorithm, count)) - otk_rows = list(txn) - if not otk_rows: - return [] - - self.db_pool.simple_delete_many_txn( - txn, - table="e2e_one_time_keys_json", - column="key_id", - values=[otk_row[0] for otk_row in otk_rows], - keyvalues={ - "user_id": user_id, - "device_id": device_id, - "algorithm": algorithm, - }, - ) - self._invalidate_cache_and_stream( - txn, self.count_e2e_one_time_keys, (user_id, device_id) - ) - - return [ - (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows - ] - - @trace - def _claim_e2e_one_time_key_returning( - txn: LoggingTransaction, - user_id: str, - device_id: str, - algorithm: str, - count: int, - ) -> List[Tuple[str, str]]: - """Claim OTK for device for DBs that support RETURNING. - - Returns: - A tuple of key name (algorithm + key ID) and key JSON, if an - OTK was found. - """ - - # We can use RETURNING to do the fetch and DELETE in once step. - sql = """ - DELETE FROM e2e_one_time_keys_json - WHERE user_id = ? AND device_id = ? AND algorithm = ? - AND key_id IN ( - SELECT key_id FROM e2e_one_time_keys_json - WHERE user_id = ? AND device_id = ? AND algorithm = ? - LIMIT ? - ) - RETURNING key_id, key_json - """ - - txn.execute( - sql, - (user_id, device_id, algorithm, user_id, device_id, algorithm, count), - ) - otk_rows = list(txn) - if not otk_rows: - return [] - - self._invalidate_cache_and_stream( - txn, self.count_e2e_one_time_keys, (user_id, device_id) - ) - - return [ - (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows - ] - results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} missing: List[Tuple[str, str, str, int]] = [] - for user_id, device_id, algorithm, count in query_list: - if self.database_engine.supports_returning: - # If we support RETURNING clause we can use a single query that - # allows us to use autocommit mode. - _claim_e2e_one_time_key = _claim_e2e_one_time_key_returning - db_autocommit = True - else: - _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple - db_autocommit = False + if isinstance(self.database_engine, PostgresEngine): + # If we can use execute_values we can use a single batch query + # in autocommit mode. + unfulfilled_claim_counts: Dict[Tuple[str, str, str], int] = {} + for user_id, device_id, algorithm, count in query_list: + unfulfilled_claim_counts[user_id, device_id, algorithm] = count - claim_rows = await self.db_pool.runInteraction( + bulk_claims = await self.db_pool.runInteraction( "claim_e2e_one_time_keys", - _claim_e2e_one_time_key, - user_id, - device_id, - algorithm, - count, - db_autocommit=db_autocommit, + self._claim_e2e_one_time_keys_bulk, + query_list, + db_autocommit=True, ) - if claim_rows: + + for user_id, device_id, algorithm, key_id, key_json in bulk_claims: device_results = results.setdefault(user_id, {}).setdefault( device_id, {} ) - for claim_row in claim_rows: - device_results[claim_row[0]] = json_decoder.decode(claim_row[1]) + device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json) + unfulfilled_claim_counts[(user_id, device_id, algorithm)] -= 1 + # Did we get enough OTKs? - count -= len(claim_rows) - if count: - missing.append((user_id, device_id, algorithm, count)) + missing = [ + (user, device, alg, count) + for (user, device, alg), count in unfulfilled_claim_counts.items() + if count > 0 + ] + else: + for user_id, device_id, algorithm, count in query_list: + claim_rows = await self.db_pool.runInteraction( + "claim_e2e_one_time_keys", + self._claim_e2e_one_time_key_simple, + user_id, + device_id, + algorithm, + count, + db_autocommit=False, + ) + if claim_rows: + device_results = results.setdefault(user_id, {}).setdefault( + device_id, {} + ) + for claim_row in claim_rows: + device_results[claim_row[0]] = json_decoder.decode(claim_row[1]) + # Did we get enough OTKs? + count -= len(claim_rows) + if count: + missing.append((user_id, device_id, algorithm, count)) return results, missing @@ -1362,6 +1294,99 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker return results + @trace + def _claim_e2e_one_time_key_simple( + self, + txn: LoggingTransaction, + user_id: str, + device_id: str, + algorithm: str, + count: int, + ) -> List[Tuple[str, str]]: + """Claim OTK for device for DBs that don't support RETURNING. + + Returns: + A tuple of key name (algorithm + key ID) and key JSON, if an + OTK was found. + """ + + sql = """ + SELECT key_id, key_json FROM e2e_one_time_keys_json + WHERE user_id = ? AND device_id = ? AND algorithm = ? + LIMIT ? + """ + + txn.execute(sql, (user_id, device_id, algorithm, count)) + otk_rows = list(txn) + if not otk_rows: + return [] + + self.db_pool.simple_delete_many_txn( + txn, + table="e2e_one_time_keys_json", + column="key_id", + values=[otk_row[0] for otk_row in otk_rows], + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + }, + ) + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id) + ) + + return [(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows] + + @trace + def _claim_e2e_one_time_keys_bulk( + self, + txn: LoggingTransaction, + query_list: Iterable[Tuple[str, str, str, int]], + ) -> List[Tuple[str, str, str, str, str]]: + """Bulk claim OTKs, for DBs that support DELETE FROM... RETURNING. + + Args: + query_list: Collection of tuples (user_id, device_id, algorithm, count) + as passed to claim_e2e_one_time_keys. + + Returns: + A list of tuples (user_id, device_id, algorithm, key_id, key_json) + for each OTK claimed. + """ + sql = """ + WITH claims(user_id, device_id, algorithm, claim_count) AS ( + VALUES ? + ), ranked_keys AS ( + SELECT + user_id, device_id, algorithm, key_id, claim_count, + ROW_NUMBER() OVER (PARTITION BY (user_id, device_id, algorithm)) AS r + FROM e2e_one_time_keys_json + JOIN claims USING (user_id, device_id, algorithm) + ) + DELETE FROM e2e_one_time_keys_json k + WHERE (user_id, device_id, algorithm, key_id) IN ( + SELECT user_id, device_id, algorithm, key_id + FROM ranked_keys + WHERE r <= claim_count + ) + RETURNING user_id, device_id, algorithm, key_id, key_json; + """ + otk_rows = cast( + List[Tuple[str, str, str, str, str]], txn.execute_values(sql, query_list) + ) + + seen_user_device: Set[Tuple[str, str]] = set() + for user_id, device_id, _, _, _ in otk_rows: + if (user_id, device_id) in seen_user_device: + continue + seen_user_device.add((user_id, device_id)) + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id) + ) + + return otk_rows + class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): def __init__( |