summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/13956.bugfix1
-rw-r--r--synapse/storage/database.py60
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py83
3 files changed, 115 insertions, 29 deletions
diff --git a/changelog.d/13956.bugfix b/changelog.d/13956.bugfix
new file mode 100644
index 0000000000..5682c3e002
--- /dev/null
+++ b/changelog.d/13956.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug where `POST /_matrix/client/v3/keys/query` requests could result in excessively large SQL queries.
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index a252f8eaa0..b4469eb964 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -2461,6 +2461,66 @@ def make_in_list_sql_clause(
         return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)
 
 
+# These overloads ensure that `columns` and `iterable` values have the same length.
+# Suppress "Single overload definition, multiple required" complaint.
+@overload  # type: ignore[misc]
+def make_tuple_in_list_sql_clause(
+    database_engine: BaseDatabaseEngine,
+    columns: Tuple[str, str],
+    iterable: Collection[Tuple[Any, Any]],
+) -> Tuple[str, list]:
+    ...
+
+
+def make_tuple_in_list_sql_clause(
+    database_engine: BaseDatabaseEngine,
+    columns: Tuple[str, ...],
+    iterable: Collection[Tuple[Any, ...]],
+) -> Tuple[str, list]:
+    """Returns an SQL clause that checks the given tuple of columns is in the iterable.
+
+    Args:
+        database_engine
+        columns: Names of the columns in the tuple.
+        iterable: The tuples to check the columns against.
+
+    Returns:
+        A tuple of SQL query and the args
+    """
+    if len(columns) == 0:
+        # Should be unreachable due to mypy, as long as the overloads are set up right.
+        if () in iterable:
+            return "TRUE", []
+        else:
+            return "FALSE", []
+
+    if len(columns) == 1:
+        # Use `= ANY(?)` on postgres.
+        return make_in_list_sql_clause(
+            database_engine, next(iter(columns)), [values[0] for values in iterable]
+        )
+
+    # There are multiple columns. Avoid using an `= ANY(?)` clause on postgres, as
+    # indices are not used when there are multiple columns. Instead, use an `IN`
+    # expression.
+    #
+    # `IN ((?, ...), ...)` with tuples is supported by postgres only, whereas
+    # `IN (VALUES (?, ...), ...)` is supported by both sqlite and postgres.
+    # Thus, the latter is chosen.
+
+    if len(iterable) == 0:
+        # A 0-length `VALUES` list is not allowed in sqlite or postgres.
+        # Also note that a 0-length `IN (...)` clause (not using `VALUES`) is not
+        # allowed in postgres.
+        return "FALSE", []
+
+    tuple_sql = "(%s)" % (",".join("?" for _ in columns),)
+    return "(%s) IN (VALUES %s)" % (
+        ",".join(column for column in columns),
+        ",".join(tuple_sql for _ in iterable),
+    ), [value for values in iterable for value in values]
+
+
 KV = TypeVar("KV")
 
 
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 8e9e1b0b4b..8a10ae800c 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -43,6 +43,7 @@ from synapse.storage.database import (
     LoggingDatabaseConnection,
     LoggingTransaction,
     make_in_list_sql_clause,
+    make_tuple_in_list_sql_clause,
 )
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.engines import PostgresEngine
@@ -278,7 +279,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
     def _get_e2e_device_keys_txn(
         self,
         txn: LoggingTransaction,
-        query_list: Collection[Tuple[str, str]],
+        query_list: Collection[Tuple[str, Optional[str]]],
         include_all_devices: bool = False,
         include_deleted_devices: bool = False,
     ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
@@ -288,8 +289,8 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         cross-signing signatures which have been added subsequently (for which, see
         get_e2e_device_keys_and_signatures)
         """
-        query_clauses = []
-        query_params = []
+        query_clauses: List[str] = []
+        query_params_list: List[List[object]] = []
 
         if include_all_devices is False:
             include_deleted_devices = False
@@ -297,40 +298,64 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         if include_deleted_devices:
             deleted_devices = set(query_list)
 
+        # Split the query list into queries for users and queries for particular
+        # devices.
+        user_list = []
+        user_device_list = []
         for (user_id, device_id) in query_list:
-            query_clause = "user_id = ?"
-            query_params.append(user_id)
-
-            if device_id is not None:
-                query_clause += " AND device_id = ?"
-                query_params.append(device_id)
-
-            query_clauses.append(query_clause)
-
-        sql = (
-            "SELECT user_id, device_id, "
-            "    d.display_name, "
-            "    k.key_json"
-            " FROM devices d"
-            "    %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
-            " WHERE %s AND NOT d.hidden"
-        ) % (
-            "LEFT" if include_all_devices else "INNER",
-            " OR ".join("(" + q + ")" for q in query_clauses),
-        )
+            if device_id is None:
+                user_list.append(user_id)
+            else:
+                user_device_list.append((user_id, device_id))
 
-        txn.execute(sql, query_params)
+        if user_list:
+            user_id_in_list_clause, user_args = make_in_list_sql_clause(
+                txn.database_engine, "user_id", user_list
+            )
+            query_clauses.append(user_id_in_list_clause)
+            query_params_list.append(user_args)
+
+        if user_device_list:
+            # Divide the device queries into batches, to avoid excessively large
+            # queries.
+            for user_device_batch in batch_iter(user_device_list, 1024):
+                (
+                    user_device_id_in_list_clause,
+                    user_device_args,
+                ) = make_tuple_in_list_sql_clause(
+                    txn.database_engine, ("user_id", "device_id"), user_device_batch
+                )
+                query_clauses.append(user_device_id_in_list_clause)
+                query_params_list.append(user_device_args)
 
         result: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] = {}
-        for (user_id, device_id, display_name, key_json) in txn:
-            if include_deleted_devices:
-                deleted_devices.remove((user_id, device_id))
-            result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
-                display_name, db_to_json(key_json) if key_json else None
+        for query_clause, query_params in zip(query_clauses, query_params_list):
+            sql = (
+                "SELECT user_id, device_id, "
+                "    d.display_name, "
+                "    k.key_json"
+                " FROM devices d"
+                "    %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
+                " WHERE %s AND NOT d.hidden"
+            ) % (
+                "LEFT" if include_all_devices else "INNER",
+                query_clause,
             )
 
+            txn.execute(sql, query_params)
+
+            for (user_id, device_id, display_name, key_json) in txn:
+                assert device_id is not None
+                if include_deleted_devices:
+                    deleted_devices.remove((user_id, device_id))
+                result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
+                    display_name, db_to_json(key_json) if key_json else None
+                )
+
         if include_deleted_devices:
             for user_id, device_id in deleted_devices:
+                if device_id is None:
+                    continue
                 result.setdefault(user_id, {})[device_id] = None
 
         return result