diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 1a4ae55304..4bc391f213 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -1027,8 +1027,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
...
async def claim_e2e_one_time_keys(
- self, query_list: Iterable[Tuple[str, str, str]]
- ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
+ self, query_list: Iterable[Tuple[str, str, str, int]]
+ ) -> Tuple[
+ Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
+ ]:
"""Take a list of one time keys out of the database.
Args:
@@ -1043,8 +1045,12 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
@trace
def _claim_e2e_one_time_key_simple(
- txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
- ) -> Optional[Tuple[str, str]]:
+ txn: LoggingTransaction,
+ user_id: str,
+ device_id: str,
+ algorithm: str,
+ count: int,
+ ) -> List[Tuple[str, str]]:
"""Claim OTK for device for DBs that don't support RETURNING.
Returns:
@@ -1055,36 +1061,41 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
sql = """
SELECT key_id, key_json FROM e2e_one_time_keys_json
WHERE user_id = ? AND device_id = ? AND algorithm = ?
- LIMIT 1
+ LIMIT ?
"""
- txn.execute(sql, (user_id, device_id, algorithm))
- otk_row = txn.fetchone()
- if otk_row is None:
- return None
+ txn.execute(sql, (user_id, device_id, algorithm, count))
+ otk_rows = list(txn)
+ if not otk_rows:
+ return []
- key_id, key_json = otk_row
-
- self.db_pool.simple_delete_one_txn(
+ self.db_pool.simple_delete_many_txn(
txn,
table="e2e_one_time_keys_json",
+ column="key_id",
+ values=[otk_row[0] for otk_row in otk_rows],
keyvalues={
"user_id": user_id,
"device_id": device_id,
"algorithm": algorithm,
- "key_id": key_id,
},
)
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
- return f"{algorithm}:{key_id}", key_json
+ return [
+ (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
+ ]
@trace
def _claim_e2e_one_time_key_returning(
- txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
- ) -> Optional[Tuple[str, str]]:
+ txn: LoggingTransaction,
+ user_id: str,
+ device_id: str,
+ algorithm: str,
+ count: int,
+ ) -> List[Tuple[str, str]]:
"""Claim OTK for device for DBs that support RETURNING.
Returns:
@@ -1099,28 +1110,30 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
AND key_id IN (
SELECT key_id FROM e2e_one_time_keys_json
WHERE user_id = ? AND device_id = ? AND algorithm = ?
- LIMIT 1
+ LIMIT ?
)
RETURNING key_id, key_json
"""
txn.execute(
- sql, (user_id, device_id, algorithm, user_id, device_id, algorithm)
+ sql,
+ (user_id, device_id, algorithm, user_id, device_id, algorithm, count),
)
- otk_row = txn.fetchone()
- if otk_row is None:
- return None
+ otk_rows = list(txn)
+ if not otk_rows:
+ return []
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
- key_id, key_json = otk_row
- return f"{algorithm}:{key_id}", key_json
+ return [
+ (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
+ ]
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
- missing: List[Tuple[str, str, str]] = []
- for user_id, device_id, algorithm in query_list:
+ missing: List[Tuple[str, str, str, int]] = []
+ for user_id, device_id, algorithm, count in query_list:
if self.database_engine.supports_returning:
# If we support RETURNING clause we can use a single query that
# allows us to use autocommit mode.
@@ -1130,21 +1143,25 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
_claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
db_autocommit = False
- claim_row = await self.db_pool.runInteraction(
+ claim_rows = await self.db_pool.runInteraction(
"claim_e2e_one_time_keys",
_claim_e2e_one_time_key,
user_id,
device_id,
algorithm,
+ count,
db_autocommit=db_autocommit,
)
- if claim_row:
+ if claim_rows:
device_results = results.setdefault(user_id, {}).setdefault(
device_id, {}
)
- device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
- else:
- missing.append((user_id, device_id, algorithm))
+ for claim_row in claim_rows:
+ device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
+ # Did we get enough OTKs?
+ count -= len(claim_rows)
+ if count:
+ missing.append((user_id, device_id, algorithm, count))
return results, missing
|