diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 8c6432035d..91c5fe007d 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -659,6 +659,20 @@ class E2eKeysHandler:
timeout: Optional[int],
always_include_fallback_keys: bool,
) -> JsonDict:
+ """
+ Args:
+ query: A chain of maps from (user_id, device_id, algorithm) to the requested
+ number of keys to claim.
+ user: The user who is claiming these keys.
+ timeout: How long to wait for any federation key claim requests before
+ giving up.
+ always_include_fallback_keys: always include a fallback key for local users'
+ devices, even if we managed to claim a one-time-key.
+
+ Returns: a heterogeneous dict with two keys:
+ one_time_keys: chain of maps user ID -> device ID -> key ID -> key.
+ failures: map from remote destination to a JsonDict describing the error.
+ """
local_query: List[Tuple[str, str, str, int]] = []
remote_queries: Dict[str, Dict[str, Dict[str, Dict[str, int]]]] = {}
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index b1ece63845..a4e7048368 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -420,6 +420,16 @@ class LoggingTransaction:
self._do_execute(self.txn.execute, sql, parameters)
def executemany(self, sql: str, *args: Any) -> None:
+ """Repeatedly execute the same piece of SQL with different parameters.
+
+ See https://peps.python.org/pep-0249/#executemany. Note in particular that
+
+ > Use of this method for an operation which produces one or more result sets
+ > constitutes undefined behavior
+
+ so you can't use this for e.g. a SELECT, an UPDATE ... RETURNING, or a
+ DELETE FROM... RETURNING.
+ """
# TODO: we should add a type for *args here. Looking at Cursor.executemany
# and DBAPI2 it ought to be Sequence[_Parameter], but we pass in
# Iterable[Iterable[Any]] in execute_batch and execute_values above, which mypy
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index f70f95eeba..08385d312f 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -24,6 +24,7 @@ from typing import (
Mapping,
Optional,
Sequence,
+ Set,
Tuple,
Union,
cast,
@@ -1260,6 +1261,65 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
Returns:
A map of user ID -> a map device ID -> a map of key ID -> JSON.
"""
+ if isinstance(self.database_engine, PostgresEngine):
+ return await self.db_pool.runInteraction(
+ "_claim_e2e_fallback_keys_bulk",
+ self._claim_e2e_fallback_keys_bulk_txn,
+ query_list,
+ db_autocommit=True,
+ )
+ # Use an UPDATE FROM... RETURNING combined with a VALUES block to do
+ # everything in one query. Note: this is also supported in SQLite 3.33.0,
+ # (see https://www.sqlite.org/lang_update.html#update_from), but we do not
+ # have an equivalent of psycopg2's execute_values to do this in one query.
+ else:
+ return await self._claim_e2e_fallback_keys_simple(query_list)
+
+ def _claim_e2e_fallback_keys_bulk_txn(
+ self,
+ txn: LoggingTransaction,
+ query_list: Iterable[Tuple[str, str, str, bool]],
+ ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
+ """Efficient implementation of claim_e2e_fallback_keys for Postgres.
+
+ Safe to autocommit: this is a single query.
+ """
+ results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
+
+ sql = """
+ WITH claims(user_id, device_id, algorithm, mark_as_used) AS (
+ VALUES ?
+ )
+ UPDATE e2e_fallback_keys_json k
+ SET used = used OR mark_as_used
+ FROM claims
+ WHERE (k.user_id, k.device_id, k.algorithm) = (claims.user_id, claims.device_id, claims.algorithm)
+ RETURNING k.user_id, k.device_id, k.algorithm, k.key_id, k.key_json;
+ """
+ claimed_keys = cast(
+ List[Tuple[str, str, str, str, str]],
+ txn.execute_values(sql, query_list),
+ )
+
+ seen_user_device: Set[Tuple[str, str]] = set()
+ for user_id, device_id, algorithm, key_id, key_json in claimed_keys:
+ device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
+ device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
+
+ if (user_id, device_id) in seen_user_device:
+ continue
+ seen_user_device.add((user_id, device_id))
+ self._invalidate_cache_and_stream(
+ txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
+ )
+
+ return results
+
+ async def _claim_e2e_fallback_keys_simple(
+ self,
+ query_list: Iterable[Tuple[str, str, str, bool]],
+ ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
+ """Naive, inefficient implementation of claim_e2e_fallback_keys for SQLite."""
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for user_id, device_id, algorithm, mark_as_used in query_list:
row = await self.db_pool.simple_select_one(
|