summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py83
1 files changed, 54 insertions, 29 deletions
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 8e9e1b0b4b..8a10ae800c 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -43,6 +43,7 @@ from synapse.storage.database import (
     LoggingDatabaseConnection,
     LoggingTransaction,
     make_in_list_sql_clause,
+    make_tuple_in_list_sql_clause,
 )
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.engines import PostgresEngine
@@ -278,7 +279,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
     def _get_e2e_device_keys_txn(
         self,
         txn: LoggingTransaction,
-        query_list: Collection[Tuple[str, str]],
+        query_list: Collection[Tuple[str, Optional[str]]],
         include_all_devices: bool = False,
         include_deleted_devices: bool = False,
     ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
@@ -288,8 +289,8 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         cross-signing signatures which have been added subsequently (for which, see
         get_e2e_device_keys_and_signatures)
         """
-        query_clauses = []
-        query_params = []
+        query_clauses: List[str] = []
+        query_params_list: List[List[object]] = []
 
         if include_all_devices is False:
             include_deleted_devices = False
@@ -297,40 +298,64 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         if include_deleted_devices:
             deleted_devices = set(query_list)
 
+        # Split the query list into queries for users and queries for particular
+        # devices.
+        user_list = []
+        user_device_list = []
         for (user_id, device_id) in query_list:
-            query_clause = "user_id = ?"
-            query_params.append(user_id)
-
-            if device_id is not None:
-                query_clause += " AND device_id = ?"
-                query_params.append(device_id)
-
-            query_clauses.append(query_clause)
-
-        sql = (
-            "SELECT user_id, device_id, "
-            "    d.display_name, "
-            "    k.key_json"
-            " FROM devices d"
-            "    %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
-            " WHERE %s AND NOT d.hidden"
-        ) % (
-            "LEFT" if include_all_devices else "INNER",
-            " OR ".join("(" + q + ")" for q in query_clauses),
-        )
+            if device_id is None:
+                user_list.append(user_id)
+            else:
+                user_device_list.append((user_id, device_id))
 
-        txn.execute(sql, query_params)
+        if user_list:
+            user_id_in_list_clause, user_args = make_in_list_sql_clause(
+                txn.database_engine, "user_id", user_list
+            )
+            query_clauses.append(user_id_in_list_clause)
+            query_params_list.append(user_args)
+
+        if user_device_list:
+            # Divide the device queries into batches, to avoid excessively large
+            # queries.
+            for user_device_batch in batch_iter(user_device_list, 1024):
+                (
+                    user_device_id_in_list_clause,
+                    user_device_args,
+                ) = make_tuple_in_list_sql_clause(
+                    txn.database_engine, ("user_id", "device_id"), user_device_batch
+                )
+                query_clauses.append(user_device_id_in_list_clause)
+                query_params_list.append(user_device_args)
 
         result: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] = {}
-        for (user_id, device_id, display_name, key_json) in txn:
-            if include_deleted_devices:
-                deleted_devices.remove((user_id, device_id))
-            result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
-                display_name, db_to_json(key_json) if key_json else None
+        for query_clause, query_params in zip(query_clauses, query_params_list):
+            sql = (
+                "SELECT user_id, device_id, "
+                "    d.display_name, "
+                "    k.key_json"
+                " FROM devices d"
+                "    %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
+                " WHERE %s AND NOT d.hidden"
+            ) % (
+                "LEFT" if include_all_devices else "INNER",
+                query_clause,
             )
 
+            txn.execute(sql, query_params)
+
+            for (user_id, device_id, display_name, key_json) in txn:
+                assert device_id is not None
+                if include_deleted_devices:
+                    deleted_devices.remove((user_id, device_id))
+                result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
+                    display_name, db_to_json(key_json) if key_json else None
+                )
+
         if include_deleted_devices:
             for user_id, device_id in deleted_devices:
+                if device_id is None:
+                    continue
                 result.setdefault(user_id, {})[device_id] = None
 
         return result