diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 91c5fe007d..d340d4aebe 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -753,6 +753,16 @@ class E2eKeysHandler:
async def upload_keys_for_user(
self, user_id: str, device_id: str, keys: JsonDict
) -> JsonDict:
+ """
+ Args:
+ user_id: user whose keys are being uploaded.
+ device_id: device whose keys are being uploaded.
+ keys: the body of a /keys/upload request.
+
+ Returns a dictionary with one field:
+ "one_time_keys": A mapping from algorithm to number of keys for that
+ algorithm, including those previously persisted.
+ """
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 08385d312f..4f96ac25c7 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -1111,7 +1111,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
...
async def claim_e2e_one_time_keys(
- self, query_list: Iterable[Tuple[str, str, str, int]]
+ self, query_list: Collection[Tuple[str, str, str, int]]
) -> Tuple[
Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
]:
@@ -1121,131 +1121,63 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
query_list: An iterable of tuples of (user ID, device ID, algorithm).
Returns:
- A tuple pf:
+ A tuple (results, missing) of:
A map of user ID -> a map device ID -> a map of key ID -> JSON.
- A copy of the input which has not been fulfilled.
+ A copy of the input which has not been fulfilled. The returned counts
+ may be less than the input counts. In this case, the returned counts
+ are the number of claims that were not fulfilled.
"""
-
- @trace
- def _claim_e2e_one_time_key_simple(
- txn: LoggingTransaction,
- user_id: str,
- device_id: str,
- algorithm: str,
- count: int,
- ) -> List[Tuple[str, str]]:
- """Claim OTK for device for DBs that don't support RETURNING.
-
- Returns:
- A tuple of key name (algorithm + key ID) and key JSON, if an
- OTK was found.
- """
-
- sql = """
- SELECT key_id, key_json FROM e2e_one_time_keys_json
- WHERE user_id = ? AND device_id = ? AND algorithm = ?
- LIMIT ?
- """
-
- txn.execute(sql, (user_id, device_id, algorithm, count))
- otk_rows = list(txn)
- if not otk_rows:
- return []
-
- self.db_pool.simple_delete_many_txn(
- txn,
- table="e2e_one_time_keys_json",
- column="key_id",
- values=[otk_row[0] for otk_row in otk_rows],
- keyvalues={
- "user_id": user_id,
- "device_id": device_id,
- "algorithm": algorithm,
- },
- )
- self._invalidate_cache_and_stream(
- txn, self.count_e2e_one_time_keys, (user_id, device_id)
- )
-
- return [
- (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
- ]
-
- @trace
- def _claim_e2e_one_time_key_returning(
- txn: LoggingTransaction,
- user_id: str,
- device_id: str,
- algorithm: str,
- count: int,
- ) -> List[Tuple[str, str]]:
- """Claim OTK for device for DBs that support RETURNING.
-
- Returns:
- A tuple of key name (algorithm + key ID) and key JSON, if an
- OTK was found.
- """
-
- # We can use RETURNING to do the fetch and DELETE in once step.
- sql = """
- DELETE FROM e2e_one_time_keys_json
- WHERE user_id = ? AND device_id = ? AND algorithm = ?
- AND key_id IN (
- SELECT key_id FROM e2e_one_time_keys_json
- WHERE user_id = ? AND device_id = ? AND algorithm = ?
- LIMIT ?
- )
- RETURNING key_id, key_json
- """
-
- txn.execute(
- sql,
- (user_id, device_id, algorithm, user_id, device_id, algorithm, count),
- )
- otk_rows = list(txn)
- if not otk_rows:
- return []
-
- self._invalidate_cache_and_stream(
- txn, self.count_e2e_one_time_keys, (user_id, device_id)
- )
-
- return [
- (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
- ]
-
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
missing: List[Tuple[str, str, str, int]] = []
- for user_id, device_id, algorithm, count in query_list:
- if self.database_engine.supports_returning:
- # If we support RETURNING clause we can use a single query that
- # allows us to use autocommit mode.
- _claim_e2e_one_time_key = _claim_e2e_one_time_key_returning
- db_autocommit = True
- else:
- _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
- db_autocommit = False
+ if isinstance(self.database_engine, PostgresEngine):
+ # If we can use execute_values we can use a single batch query
+ # in autocommit mode.
+ unfulfilled_claim_counts: Dict[Tuple[str, str, str], int] = {}
+ for user_id, device_id, algorithm, count in query_list:
+ unfulfilled_claim_counts[user_id, device_id, algorithm] = count
- claim_rows = await self.db_pool.runInteraction(
+ bulk_claims = await self.db_pool.runInteraction(
"claim_e2e_one_time_keys",
- _claim_e2e_one_time_key,
- user_id,
- device_id,
- algorithm,
- count,
- db_autocommit=db_autocommit,
+ self._claim_e2e_one_time_keys_bulk,
+ query_list,
+ db_autocommit=True,
)
- if claim_rows:
+
+ for user_id, device_id, algorithm, key_id, key_json in bulk_claims:
device_results = results.setdefault(user_id, {}).setdefault(
device_id, {}
)
- for claim_row in claim_rows:
- device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
+ device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
+ unfulfilled_claim_counts[(user_id, device_id, algorithm)] -= 1
+
# Did we get enough OTKs?
- count -= len(claim_rows)
- if count:
- missing.append((user_id, device_id, algorithm, count))
+ missing = [
+ (user, device, alg, count)
+ for (user, device, alg), count in unfulfilled_claim_counts.items()
+ if count > 0
+ ]
+ else:
+ for user_id, device_id, algorithm, count in query_list:
+ claim_rows = await self.db_pool.runInteraction(
+ "claim_e2e_one_time_keys",
+ self._claim_e2e_one_time_key_simple,
+ user_id,
+ device_id,
+ algorithm,
+ count,
+ db_autocommit=False,
+ )
+ if claim_rows:
+ device_results = results.setdefault(user_id, {}).setdefault(
+ device_id, {}
+ )
+ for claim_row in claim_rows:
+ device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
+ # Did we get enough OTKs?
+ count -= len(claim_rows)
+ if count:
+ missing.append((user_id, device_id, algorithm, count))
return results, missing
@@ -1362,6 +1294,99 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
return results
+ @trace
+ def _claim_e2e_one_time_key_simple(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_id: str,
+ algorithm: str,
+ count: int,
+ ) -> List[Tuple[str, str]]:
+ """Claim OTK for device for DBs that don't support RETURNING.
+
+ Returns:
+ A tuple of key name (algorithm + key ID) and key JSON, if an
+ OTK was found.
+ """
+
+ sql = """
+ SELECT key_id, key_json FROM e2e_one_time_keys_json
+ WHERE user_id = ? AND device_id = ? AND algorithm = ?
+ LIMIT ?
+ """
+
+ txn.execute(sql, (user_id, device_id, algorithm, count))
+ otk_rows = list(txn)
+ if not otk_rows:
+ return []
+
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="e2e_one_time_keys_json",
+ column="key_id",
+ values=[otk_row[0] for otk_row in otk_rows],
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ "algorithm": algorithm,
+ },
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id)
+ )
+
+ return [(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows]
+
+ @trace
+ def _claim_e2e_one_time_keys_bulk(
+ self,
+ txn: LoggingTransaction,
+ query_list: Iterable[Tuple[str, str, str, int]],
+ ) -> List[Tuple[str, str, str, str, str]]:
+ """Bulk claim OTKs, for DBs that support DELETE FROM... RETURNING.
+
+ Args:
+ query_list: Collection of tuples (user_id, device_id, algorithm, count)
+ as passed to claim_e2e_one_time_keys.
+
+ Returns:
+ A list of tuples (user_id, device_id, algorithm, key_id, key_json)
+ for each OTK claimed.
+ """
+ sql = """
+ WITH claims(user_id, device_id, algorithm, claim_count) AS (
+ VALUES ?
+ ), ranked_keys AS (
+ SELECT
+ user_id, device_id, algorithm, key_id, claim_count,
+ ROW_NUMBER() OVER (PARTITION BY (user_id, device_id, algorithm)) AS r
+ FROM e2e_one_time_keys_json
+ JOIN claims USING (user_id, device_id, algorithm)
+ )
+ DELETE FROM e2e_one_time_keys_json k
+ WHERE (user_id, device_id, algorithm, key_id) IN (
+ SELECT user_id, device_id, algorithm, key_id
+ FROM ranked_keys
+ WHERE r <= claim_count
+ )
+ RETURNING user_id, device_id, algorithm, key_id, key_json;
+ """
+ otk_rows = cast(
+ List[Tuple[str, str, str, str, str]], txn.execute_values(sql, query_list)
+ )
+
+ seen_user_device: Set[Tuple[str, str]] = set()
+ for user_id, device_id, _, _, _ in otk_rows:
+ if (user_id, device_id) in seen_user_device:
+ continue
+ seen_user_device.add((user_id, device_id))
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id)
+ )
+
+ return otk_rows
+
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def __init__(
|