diff options
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/database.py | 60 | ||||
-rw-r--r-- | synapse/storage/databases/main/end_to_end_keys.py | 83 |
2 files changed, 114 insertions, 29 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py index a252f8eaa0..b4469eb964 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -2461,6 +2461,66 @@ def make_in_list_sql_clause( return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable) +# These overloads ensure that `columns` and `iterable` values have the same length. +# Suppress "Single overload definition, multiple required" complaint. +@overload # type: ignore[misc] +def make_tuple_in_list_sql_clause( + database_engine: BaseDatabaseEngine, + columns: Tuple[str, str], + iterable: Collection[Tuple[Any, Any]], +) -> Tuple[str, list]: + ... + + +def make_tuple_in_list_sql_clause( + database_engine: BaseDatabaseEngine, + columns: Tuple[str, ...], + iterable: Collection[Tuple[Any, ...]], +) -> Tuple[str, list]: + """Returns an SQL clause that checks the given tuple of columns is in the iterable. + + Args: + database_engine + columns: Names of the columns in the tuple. + iterable: The tuples to check the columns against. + + Returns: + A tuple of SQL query and the args + """ + if len(columns) == 0: + # Should be unreachable due to mypy, as long as the overloads are set up right. + if () in iterable: + return "TRUE", [] + else: + return "FALSE", [] + + if len(columns) == 1: + # Use `= ANY(?)` on postgres. + return make_in_list_sql_clause( + database_engine, next(iter(columns)), [values[0] for values in iterable] + ) + + # There are multiple columns. Avoid using an `= ANY(?)` clause on postgres, as + # indices are not used when there are multiple columns. Instead, use an `IN` + # expression. + # + # `IN ((?, ...), ...)` with tuples is supported by postgres only, whereas + # `IN (VALUES (?, ...), ...)` is supported by both sqlite and postgres. + # Thus, the latter is chosen. + + if len(iterable) == 0: + # A 0-length `VALUES` list is not allowed in sqlite or postgres. + # Also note that a 0-length `IN (...)` clause (not using `VALUES`) is not + # allowed in postgres. + return "FALSE", [] + + tuple_sql = "(%s)" % (",".join("?" for _ in columns),) + return "(%s) IN (VALUES %s)" % ( + ",".join(column for column in columns), + ",".join(tuple_sql for _ in iterable), + ), [value for values in iterable for value in values] + + KV = TypeVar("KV") diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 8e9e1b0b4b..8a10ae800c 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -43,6 +43,7 @@ from synapse.storage.database import ( LoggingDatabaseConnection, LoggingTransaction, make_in_list_sql_clause, + make_tuple_in_list_sql_clause, ) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines import PostgresEngine @@ -278,7 +279,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker def _get_e2e_device_keys_txn( self, txn: LoggingTransaction, - query_list: Collection[Tuple[str, str]], + query_list: Collection[Tuple[str, Optional[str]]], include_all_devices: bool = False, include_deleted_devices: bool = False, ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: @@ -288,8 +289,8 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker cross-signing signatures which have been added subsequently (for which, see get_e2e_device_keys_and_signatures) """ - query_clauses = [] - query_params = [] + query_clauses: List[str] = [] + query_params_list: List[List[object]] = [] if include_all_devices is False: include_deleted_devices = False @@ -297,40 +298,64 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker if include_deleted_devices: deleted_devices = set(query_list) + # Split the query list into queries for users and queries for particular + # devices. + user_list = [] + user_device_list = [] for (user_id, device_id) in query_list: - query_clause = "user_id = ?" - query_params.append(user_id) - - if device_id is not None: - query_clause += " AND device_id = ?" - query_params.append(device_id) - - query_clauses.append(query_clause) - - sql = ( - "SELECT user_id, device_id, " - " d.display_name, " - " k.key_json" - " FROM devices d" - " %s JOIN e2e_device_keys_json k USING (user_id, device_id)" - " WHERE %s AND NOT d.hidden" - ) % ( - "LEFT" if include_all_devices else "INNER", - " OR ".join("(" + q + ")" for q in query_clauses), - ) + if device_id is None: + user_list.append(user_id) + else: + user_device_list.append((user_id, device_id)) - txn.execute(sql, query_params) + if user_list: + user_id_in_list_clause, user_args = make_in_list_sql_clause( + txn.database_engine, "user_id", user_list + ) + query_clauses.append(user_id_in_list_clause) + query_params_list.append(user_args) + + if user_device_list: + # Divide the device queries into batches, to avoid excessively large + # queries. + for user_device_batch in batch_iter(user_device_list, 1024): + ( + user_device_id_in_list_clause, + user_device_args, + ) = make_tuple_in_list_sql_clause( + txn.database_engine, ("user_id", "device_id"), user_device_batch + ) + query_clauses.append(user_device_id_in_list_clause) + query_params_list.append(user_device_args) result: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] = {} - for (user_id, device_id, display_name, key_json) in txn: - if include_deleted_devices: - deleted_devices.remove((user_id, device_id)) - result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult( - display_name, db_to_json(key_json) if key_json else None + for query_clause, query_params in zip(query_clauses, query_params_list): + sql = ( + "SELECT user_id, device_id, " + " d.display_name, " + " k.key_json" + " FROM devices d" + " %s JOIN e2e_device_keys_json k USING (user_id, device_id)" + " WHERE %s AND NOT d.hidden" + ) % ( + "LEFT" if include_all_devices else "INNER", + query_clause, ) + txn.execute(sql, query_params) + + for (user_id, device_id, display_name, key_json) in txn: + assert device_id is not None + if include_deleted_devices: + deleted_devices.remove((user_id, device_id)) + result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult( + display_name, db_to_json(key_json) if key_json else None + ) + if include_deleted_devices: for user_id, device_id in deleted_devices: + if device_id is None: + continue result.setdefault(user_id, {})[device_id] = None return result |