diff --git a/changelog.d/13956.bugfix b/changelog.d/13956.bugfix
new file mode 100644
index 0000000000..5682c3e002
--- /dev/null
+++ b/changelog.d/13956.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug where `POST /_matrix/client/v3/keys/query` requests could result in excessively large SQL queries.
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index a252f8eaa0..b4469eb964 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -2461,6 +2461,66 @@ def make_in_list_sql_clause(
return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)
+# The overload below lets mypy verify that `columns` and every tuple in `iterable`
+# have the same length.
+# The `type: ignore` suppresses "Single overload definition, multiple required".
+@overload  # type: ignore[misc]
+def make_tuple_in_list_sql_clause(
+    database_engine: BaseDatabaseEngine,
+    columns: Tuple[str, str],
+    iterable: Collection[Tuple[Any, Any]],
+) -> Tuple[str, list]:
+    ...
+
+
+def make_tuple_in_list_sql_clause(
+    database_engine: BaseDatabaseEngine,
+    columns: Tuple[str, ...],
+    iterable: Collection[Tuple[Any, ...]],
+) -> Tuple[str, list]:
+    """Builds an SQL clause testing whether a tuple of columns matches any given tuple.
+
+    With no columns the clause is the constant `TRUE` or `FALSE`; with exactly one
+    column the work is handed to `make_in_list_sql_clause`; with several columns and
+    no value tuples the clause is the constant `FALSE`.
+
+    Args:
+        database_engine
+        columns: Names of the columns making up the tuple.
+        iterable: The value tuples to match the columns against.
+
+    Returns:
+        A tuple of the SQL clause and the list of query args for it.
+    """
+    column_count = len(columns)
+
+    if column_count == 0:
+        # Should be unreachable as long as the overloads are set up right and mypy
+        # is run: an empty column tuple can only equal an empty value tuple.
+        return ("TRUE" if () in iterable else "FALSE"), []
+
+    if column_count == 1:
+        # A lone column needs no tuple comparison, so `= ANY(?)` can be used on
+        # postgres.
+        (only_column,) = columns
+        first_values = [row[0] for row in iterable]
+        return make_in_list_sql_clause(database_engine, only_column, first_values)
+
+    # Several columns. `= ANY(?)` is avoided on postgres because indices are not
+    # used with it when there are multiple columns, so an `IN` expression is built
+    # instead.
+    #
+    # Postgres alone accepts `IN ((?, ...), ...)` with bare tuples, while both
+    # sqlite and postgres accept `IN (VALUES (?, ...), ...)`; the portable form is
+    # the one emitted here.
+    row_count = len(iterable)
+    if row_count == 0:
+        # Neither sqlite nor postgres allows an empty `VALUES` list, and postgres
+        # also rejects an empty plain `IN (...)` list.
+        return "FALSE", []
+
+    row_placeholder = "(%s)" % (",".join("?" * column_count),)
+    clause = "(%s) IN (VALUES %s)" % (
+        ",".join(columns),
+        ",".join([row_placeholder] * row_count),
+    )
+
+    # Flatten the value tuples, row by row, into one positional argument list.
+    args: list = []
+    for row in iterable:
+        args.extend(row)
+
+    return clause, args
+
+
KV = TypeVar("KV")
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 8e9e1b0b4b..8a10ae800c 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -43,6 +43,7 @@ from synapse.storage.database import (
LoggingDatabaseConnection,
LoggingTransaction,
make_in_list_sql_clause,
+ make_tuple_in_list_sql_clause,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
@@ -278,7 +279,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
def _get_e2e_device_keys_txn(
self,
txn: LoggingTransaction,
- query_list: Collection[Tuple[str, str]],
+ query_list: Collection[Tuple[str, Optional[str]]],
include_all_devices: bool = False,
include_deleted_devices: bool = False,
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
@@ -288,8 +289,8 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
cross-signing signatures which have been added subsequently (for which, see
get_e2e_device_keys_and_signatures)
"""
- query_clauses = []
- query_params = []
+ # `query_clauses[i]` is the WHERE condition whose args are `query_params_list[i]`.
+ query_clauses: List[str] = []
+ query_params_list: List[List[object]] = []
+ # Deleted devices are only reported when all devices are being included.
if include_all_devices is False:
include_deleted_devices = False
@@ -297,40 +298,64 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
+ # Treat every queried pair as deleted until a matching row is seen below.
if include_deleted_devices:
deleted_devices = set(query_list)
+ # Split the query list into queries for users and queries for particular
+ # devices.
+ user_list = []
+ user_device_list = []
for (user_id, device_id) in query_list:
- query_clause = "user_id = ?"
- query_params.append(user_id)
-
- if device_id is not None:
- query_clause += " AND device_id = ?"
- query_params.append(device_id)
-
- query_clauses.append(query_clause)
-
- sql = (
- "SELECT user_id, device_id, "
- " d.display_name, "
- " k.key_json"
- " FROM devices d"
- " %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
- " WHERE %s AND NOT d.hidden"
- ) % (
- "LEFT" if include_all_devices else "INNER",
- " OR ".join("(" + q + ")" for q in query_clauses),
- )
+ if device_id is None:
+ user_list.append(user_id)
+ else:
+ user_device_list.append((user_id, device_id))
- txn.execute(sql, query_params)
+ # Whole-user lookups all go into one clause built by `make_in_list_sql_clause`;
+ # unlike the per-device lookups below, they are not batched.
+ if user_list:
+ user_id_in_list_clause, user_args = make_in_list_sql_clause(
+ txn.database_engine, "user_id", user_list
+ )
+ query_clauses.append(user_id_in_list_clause)
+ query_params_list.append(user_args)
+
+ if user_device_list:
+ # Divide the device queries into batches, to avoid excessively large
+ # queries.
+ for user_device_batch in batch_iter(user_device_list, 1024):
+ (
+ user_device_id_in_list_clause,
+ user_device_args,
+ ) = make_tuple_in_list_sql_clause(
+ txn.database_engine, ("user_id", "device_id"), user_device_batch
+ )
+ query_clauses.append(user_device_id_in_list_clause)
+ query_params_list.append(user_device_args)
result: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] = {}
- for (user_id, device_id, display_name, key_json) in txn:
- if include_deleted_devices:
- deleted_devices.remove((user_id, device_id))
- result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
- display_name, db_to_json(key_json) if key_json else None
+ # Run one SELECT per clause; rows from every statement accumulate in `result`.
+ for query_clause, query_params in zip(query_clauses, query_params_list):
+ sql = (
+ "SELECT user_id, device_id, "
+ " d.display_name, "
+ " k.key_json"
+ " FROM devices d"
+ " %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
+ " WHERE %s AND NOT d.hidden"
+ ) % (
+ "LEFT" if include_all_devices else "INNER",
+ query_clause,
)
+ txn.execute(sql, query_params)
+
+ for (user_id, device_id, display_name, key_json) in txn:
+ # NOTE(review): presumably this narrows the Optional[str] type that
+ # `device_id` carries over from the loop over `query_list` — confirm.
+ assert device_id is not None
+ # NOTE(review): set.remove raises KeyError when the pair is absent, e.g.
+ # for a row matched by a whole-user (device_id None) query or a row
+ # returned by two clauses — confirm callers cannot trigger this.
+ if include_deleted_devices:
+ deleted_devices.remove((user_id, device_id))
+ result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
+ display_name, db_to_json(key_json) if key_json else None
+ )
+
+ # Pairs still in `deleted_devices` matched no row; report them with a None value.
if include_deleted_devices:
for user_id, device_id in deleted_devices:
+ # Whole-user entries name no device, so there is nothing to mark.
+ if device_id is None:
+ continue
result.setdefault(user_id, {})[device_id] = None
return result
|