From fa907025f4b263d27c2b338fb0fe86d257d74fa8 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 5 Oct 2023 11:07:38 -0400 Subject: Remove many calls to cursor_to_dict (#16431) This avoids calling cursor_to_dict and then immediately unpacking the values in the dict for other users. By not creating the intermediate dictionary we can avoid allocating the dictionary and strings for the keys, which should generally be more performant. Additionally this improves type hints by avoiding Dict[str, Any] dictionaries coming out of the database layer. --- synapse/storage/databases/main/devices.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'synapse/storage/databases/main/devices.py') diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index df596f35f9..9f3804a504 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1413,13 +1413,13 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): def get_devices_not_accessed_since_txn( txn: LoggingTransaction, - ) -> List[Dict[str, str]]: + ) -> List[Tuple[str, str]]: sql = """ SELECT user_id, device_id FROM devices WHERE last_seen < ? AND hidden = FALSE """ txn.execute(sql, (since_ms,)) - return self.db_pool.cursor_to_dict(txn) + return cast(List[Tuple[str, str]], txn.fetchall()) rows = await self.db_pool.runInteraction( "get_devices_not_accessed_since", @@ -1427,11 +1427,11 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ) devices: Dict[str, List[str]] = {} - for row in rows: + for user_id, device_id in rows: # Remote devices are never stale from our point of view. 
- if self.hs.is_mine_id(row["user_id"]): - user_devices = devices.setdefault(row["user_id"], []) - user_devices.append(row["device_id"]) + if self.hs.is_mine_id(user_id): + user_devices = devices.setdefault(user_id, []) + user_devices.append(device_id) return devices -- cgit 1.5.1 From a4904dcb04b31ce8ed0deaa2c5c80657780f6618 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 11 Oct 2023 13:24:56 -0400 Subject: Convert simple_select_many_batch, simple_select_many_txn to tuples. (#16444) --- changelog.d/16444.misc | 1 + synapse/storage/database.py | 18 ++-- synapse/storage/databases/main/deviceinbox.py | 42 ++++---- synapse/storage/databases/main/devices.py | 49 ++++++---- synapse/storage/databases/main/end_to_end_keys.py | 19 ++-- synapse/storage/databases/main/event_federation.py | 107 +++++++++++---------- synapse/storage/databases/main/events.py | 79 ++++++++------- .../storage/databases/main/events_bg_updates.py | 62 ++++++------ synapse/storage/databases/main/events_worker.py | 36 ++++--- synapse/storage/databases/main/keys.py | 46 +++++---- synapse/storage/databases/main/presence.py | 51 ++++++---- synapse/storage/databases/main/push_rule.py | 97 +++++++++++++------ synapse/storage/databases/main/relations.py | 19 ++-- synapse/storage/databases/main/room.py | 19 ++-- synapse/storage/databases/main/roommember.py | 78 ++++++++------- synapse/storage/databases/main/state.py | 62 +++++++----- synapse/storage/databases/main/stats.py | 37 +++---- synapse/storage/databases/main/transactions.py | 28 ++++-- synapse/storage/databases/main/ui_auth.py | 41 ++++---- synapse/storage/databases/main/user_directory.py | 54 ++++++----- .../storage/databases/main/user_erasure_store.py | 19 ++-- synapse/storage/databases/state/store.py | 54 +++++++---- tests/storage/test_event_chain.py | 64 +++++++----- 23 files changed, 640 insertions(+), 442 deletions(-) create mode 100644 changelog.d/16444.misc (limited to 'synapse/storage/databases/main/devices.py') diff --git 
a/changelog.d/16444.misc b/changelog.d/16444.misc new file mode 100644 index 0000000000..bd7cdd42af --- /dev/null +++ b/changelog.d/16444.misc @@ -0,0 +1 @@ +Reduce memory allocations. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 7714ec2bf9..81f661160c 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -1874,9 +1874,9 @@ class DatabasePool: keyvalues: Optional[Dict[str, Any]] = None, desc: str = "simple_select_many_batch", batch_size: int = 100, - ) -> List[Dict[str, Any]]: + ) -> List[Tuple[Any, ...]]: """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. + more rows. Filters rows by whether the value of `column` is in `iterable`. @@ -1888,10 +1888,13 @@ class DatabasePool: keyvalues: dict of column names and values to select the rows with desc: description of the transaction, for logging and metrics batch_size: the number of rows for each select query + + Returns: + The results as a list of tuples. """ keyvalues = keyvalues or {} - results: List[Dict[str, Any]] = [] + results: List[Tuple[Any, ...]] = [] for chunk in batch_iter(iterable, batch_size): rows = await self.runInteraction( @@ -1918,9 +1921,9 @@ class DatabasePool: iterable: Collection[Any], keyvalues: Dict[str, Any], retcols: Iterable[str], - ) -> List[Dict[str, Any]]: + ) -> List[Tuple[Any, ...]]: """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. + more rows. Filters rows by whether the value of `column` is in `iterable`. @@ -1931,6 +1934,9 @@ class DatabasePool: iterable: list keyvalues: dict of column names and values to select the rows with retcols: list of strings giving the names of the columns to return + + Returns: + The results as a list of tuples. 
""" if not iterable: return [] @@ -1949,7 +1955,7 @@ class DatabasePool: ) txn.execute(sql, values) - return cls.cursor_to_dict(txn) + return txn.fetchall() async def simple_update( self, diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 744e98c6d0..1cf649d371 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -344,18 +344,19 @@ class DeviceInboxWorkerStore(SQLBaseStore): # Note that this is more efficient than just dropping `device_id` from the query, # since device_inbox has an index on `(user_id, device_id, stream_id)` if not device_ids_to_query: - user_device_dicts = self.db_pool.simple_select_many_txn( - txn, - table="devices", - column="user_id", - iterable=user_ids_to_query, - keyvalues={"hidden": False}, - retcols=("device_id",), + user_device_dicts = cast( + List[Tuple[str]], + self.db_pool.simple_select_many_txn( + txn, + table="devices", + column="user_id", + iterable=user_ids_to_query, + keyvalues={"hidden": False}, + retcols=("device_id",), + ), ) - device_ids_to_query.update( - {row["device_id"] for row in user_device_dicts} - ) + device_ids_to_query.update({row[0] for row in user_device_dicts}) if not device_ids_to_query: # We've ended up with no devices to query. @@ -845,20 +846,21 @@ class DeviceInboxWorkerStore(SQLBaseStore): # We exclude hidden devices (such as cross-signing keys) here as they are # not expected to receive to-device messages. 
- rows = self.db_pool.simple_select_many_txn( - txn, - table="devices", - keyvalues={"user_id": user_id, "hidden": False}, - column="device_id", - iterable=devices, - retcols=("device_id",), + rows = cast( + List[Tuple[str]], + self.db_pool.simple_select_many_txn( + txn, + table="devices", + keyvalues={"user_id": user_id, "hidden": False}, + column="device_id", + iterable=devices, + retcols=("device_id",), + ), ) - for row in rows: + for (device_id,) in rows: # Only insert into the local inbox if the device exists on # this server - device_id = row["device_id"] - with start_active_span("serialise_to_device_message"): msg = messages_by_device[device_id] set_tag(SynapseTags.TO_DEVICE_TYPE, msg["type"]) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 9f3804a504..fc23d18eba 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1052,16 +1052,19 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): async def get_device_list_last_stream_id_for_remotes( self, user_ids: Iterable[str] ) -> Mapping[str, Optional[str]]: - rows = await self.db_pool.simple_select_many_batch( - table="device_lists_remote_extremeties", - column="user_id", - iterable=user_ids, - retcols=("user_id", "stream_id"), - desc="get_device_list_last_stream_id_for_remotes", + rows = cast( + List[Tuple[str, str]], + await self.db_pool.simple_select_many_batch( + table="device_lists_remote_extremeties", + column="user_id", + iterable=user_ids, + retcols=("user_id", "stream_id"), + desc="get_device_list_last_stream_id_for_remotes", + ), ) results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids} - results.update({row["user_id"]: row["stream_id"] for row in rows}) + results.update(rows) return results @@ -1077,22 +1080,30 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): The IDs of users whose device lists need resync. 
""" if user_ids: - rows = await self.db_pool.simple_select_many_batch( - table="device_lists_remote_resync", - column="user_id", - iterable=user_ids, - retcols=("user_id",), - desc="get_user_ids_requiring_device_list_resync_with_iterable", + row_tuples = cast( + List[Tuple[str]], + await self.db_pool.simple_select_many_batch( + table="device_lists_remote_resync", + column="user_id", + iterable=user_ids, + retcols=("user_id",), + desc="get_user_ids_requiring_device_list_resync_with_iterable", + ), ) + + return {row[0] for row in row_tuples} else: - rows = await self.db_pool.simple_select_list( - table="device_lists_remote_resync", - keyvalues=None, - retcols=("user_id",), - desc="get_user_ids_requiring_device_list_resync", + rows = cast( + List[Dict[str, str]], + await self.db_pool.simple_select_list( + table="device_lists_remote_resync", + keyvalues=None, + retcols=("user_id",), + desc="get_user_ids_requiring_device_list_resync", + ), ) - return {row["user_id"] for row in rows} + return {row["user_id"] for row in rows} async def mark_remote_users_device_caches_as_stale( self, user_ids: StrCollection diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 749ae54e20..f13d776b0d 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -493,15 +493,18 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker A map from (algorithm, key_id) to json string for key """ - rows = await self.db_pool.simple_select_many_batch( - table="e2e_one_time_keys_json", - column="key_id", - iterable=key_ids, - retcols=("algorithm", "key_id", "key_json"), - keyvalues={"user_id": user_id, "device_id": device_id}, - desc="add_e2e_one_time_keys_check", + rows = cast( + List[Tuple[str, str, str]], + await self.db_pool.simple_select_many_batch( + table="e2e_one_time_keys_json", + column="key_id", + iterable=key_ids, + retcols=("algorithm", 
"key_id", "key_json"), + keyvalues={"user_id": user_id, "device_id": device_id}, + desc="add_e2e_one_time_keys_check", + ), ) - result = {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows} + result = {(algorithm, key_id): key_json for algorithm, key_id, key_json in rows} log_kv({"message": "Fetched one time keys for user", "one_time_keys": result}) return result diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index afffa54985..4f80ce75cc 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1049,15 +1049,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas Args: event_ids: The event IDs to calculate the max depth of. """ - rows = await self.db_pool.simple_select_many_batch( - table="events", - column="event_id", - iterable=event_ids, - retcols=( - "event_id", - "depth", + rows = cast( + List[Tuple[str, int]], + await self.db_pool.simple_select_many_batch( + table="events", + column="event_id", + iterable=event_ids, + retcols=( + "event_id", + "depth", + ), + desc="get_max_depth_of", ), - desc="get_max_depth_of", ) if not rows: @@ -1065,10 +1068,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas else: max_depth_event_id = "" current_max_depth = 0 - for row in rows: - if row["depth"] > current_max_depth: - max_depth_event_id = row["event_id"] - current_max_depth = row["depth"] + for event_id, depth in rows: + if depth > current_max_depth: + max_depth_event_id = event_id + current_max_depth = depth return max_depth_event_id, current_max_depth @@ -1078,15 +1081,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas Args: event_ids: The event IDs to calculate the max depth of. 
""" - rows = await self.db_pool.simple_select_many_batch( - table="events", - column="event_id", - iterable=event_ids, - retcols=( - "event_id", - "depth", + rows = cast( + List[Tuple[str, int]], + await self.db_pool.simple_select_many_batch( + table="events", + column="event_id", + iterable=event_ids, + retcols=( + "event_id", + "depth", + ), + desc="get_min_depth_of", ), - desc="get_min_depth_of", ) if not rows: @@ -1094,10 +1100,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas else: min_depth_event_id = "" current_min_depth = MAX_DEPTH - for row in rows: - if row["depth"] < current_min_depth: - min_depth_event_id = row["event_id"] - current_min_depth = row["depth"] + for event_id, depth in rows: + if depth < current_min_depth: + min_depth_event_id = event_id + current_min_depth = depth return min_depth_event_id, current_min_depth @@ -1553,19 +1559,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas A filtered down list of `event_ids` that have previous failed pull attempts. 
""" - rows = await self.db_pool.simple_select_many_batch( - table="event_failed_pull_attempts", - column="event_id", - iterable=event_ids, - keyvalues={}, - retcols=("event_id",), - desc="get_event_ids_with_failed_pull_attempts", + rows = cast( + List[Tuple[str]], + await self.db_pool.simple_select_many_batch( + table="event_failed_pull_attempts", + column="event_id", + iterable=event_ids, + keyvalues={}, + retcols=("event_id",), + desc="get_event_ids_with_failed_pull_attempts", + ), ) - event_ids_with_failed_pull_attempts: Set[str] = { - row["event_id"] for row in rows - } - - return event_ids_with_failed_pull_attempts + return {row[0] for row in rows} @trace async def get_event_ids_to_not_pull_from_backoff( @@ -1585,32 +1590,34 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas A dictionary of event_ids that should not be attempted to be pulled and the next timestamp at which we may try pulling them again. """ - event_failed_pull_attempts = await self.db_pool.simple_select_many_batch( - table="event_failed_pull_attempts", - column="event_id", - iterable=event_ids, - keyvalues={}, - retcols=( - "event_id", - "last_attempt_ts", - "num_attempts", + event_failed_pull_attempts = cast( + List[Tuple[str, int, int]], + await self.db_pool.simple_select_many_batch( + table="event_failed_pull_attempts", + column="event_id", + iterable=event_ids, + keyvalues={}, + retcols=( + "event_id", + "last_attempt_ts", + "num_attempts", + ), + desc="get_event_ids_to_not_pull_from_backoff", ), - desc="get_event_ids_to_not_pull_from_backoff", ) current_time = self._clock.time_msec() event_ids_with_backoff = {} - for event_failed_pull_attempt in event_failed_pull_attempts: - event_id = event_failed_pull_attempt["event_id"] + for event_id, last_attempt_ts, num_attempts in event_failed_pull_attempts: # Exponential back-off (up to the upper bound) so we don't try to # pull the same event over and over. ex. 2hr, 4hr, 8hr, 16hr, etc. 
backoff_end_time = ( - event_failed_pull_attempt["last_attempt_ts"] + last_attempt_ts + ( 2 ** min( - event_failed_pull_attempt["num_attempts"], + num_attempts, BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS, ) ) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index d4dcdb898c..ef6766b5e0 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -27,6 +27,7 @@ from typing import ( Optional, Set, Tuple, + Union, cast, ) @@ -501,16 +502,19 @@ class PersistEventsStore: # We ignore legacy rooms that we aren't filling the chain cover index # for. - rows = self.db_pool.simple_select_many_txn( - txn, - table="rooms", - column="room_id", - iterable={event.room_id for event in events if event.is_state()}, - keyvalues={}, - retcols=("room_id", "has_auth_chain_index"), + rows = cast( + List[Tuple[str, Optional[Union[int, bool]]]], + self.db_pool.simple_select_many_txn( + txn, + table="rooms", + column="room_id", + iterable={event.room_id for event in events if event.is_state()}, + keyvalues={}, + retcols=("room_id", "has_auth_chain_index"), + ), ) rooms_using_chain_index = { - row["room_id"] for row in rows if row["has_auth_chain_index"] + room_id for room_id, has_auth_chain_index in rows if has_auth_chain_index } state_events = { @@ -571,19 +575,18 @@ class PersistEventsStore: # We check if there are any events that need to be handled in the rooms # we're looking at. These should just be out of band memberships, where # we didn't have the auth chain when we first persisted. 
- rows = db_pool.simple_select_many_txn( - txn, - table="event_auth_chain_to_calculate", - keyvalues={}, - column="room_id", - iterable=set(event_to_room_id.values()), - retcols=("event_id", "type", "state_key"), + auth_chain_to_calc_rows = cast( + List[Tuple[str, str, str]], + db_pool.simple_select_many_txn( + txn, + table="event_auth_chain_to_calculate", + keyvalues={}, + column="room_id", + iterable=set(event_to_room_id.values()), + retcols=("event_id", "type", "state_key"), + ), ) - for row in rows: - event_id = row["event_id"] - event_type = row["type"] - state_key = row["state_key"] - + for event_id, event_type, state_key in auth_chain_to_calc_rows: # (We could pull out the auth events for all rows at once using # simple_select_many, but this case happens rarely and almost always # with a single row.) @@ -753,23 +756,31 @@ class PersistEventsStore: # Step 1, fetch all existing links from all the chains we've seen # referenced. chain_links = _LinkMap() - rows = db_pool.simple_select_many_txn( - txn, - table="event_auth_chain_links", - column="origin_chain_id", - iterable={chain_id for chain_id, _ in chain_map.values()}, - keyvalues={}, - retcols=( - "origin_chain_id", - "origin_sequence_number", - "target_chain_id", - "target_sequence_number", + auth_chain_rows = cast( + List[Tuple[int, int, int, int]], + db_pool.simple_select_many_txn( + txn, + table="event_auth_chain_links", + column="origin_chain_id", + iterable={chain_id for chain_id, _ in chain_map.values()}, + keyvalues={}, + retcols=( + "origin_chain_id", + "origin_sequence_number", + "target_chain_id", + "target_sequence_number", + ), ), ) - for row in rows: + for ( + origin_chain_id, + origin_sequence_number, + target_chain_id, + target_sequence_number, + ) in auth_chain_rows: chain_links.add_link( - (row["origin_chain_id"], row["origin_sequence_number"]), - (row["target_chain_id"], row["target_sequence_number"]), + (origin_chain_id, origin_sequence_number), + (target_chain_id, 
target_sequence_number), new=False, ) diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index daef3685b0..c5fce1c82b 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -369,18 +369,20 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)] for chunk in chunks: - ev_rows = self.db_pool.simple_select_many_txn( - txn, - table="event_json", - column="event_id", - iterable=chunk, - retcols=["event_id", "json"], - keyvalues={}, + ev_rows = cast( + List[Tuple[str, str]], + self.db_pool.simple_select_many_txn( + txn, + table="event_json", + column="event_id", + iterable=chunk, + retcols=["event_id", "json"], + keyvalues={}, + ), ) - for row in ev_rows: - event_id = row["event_id"] - event_json = db_to_json(row["json"]) + for event_id, json in ev_rows: + event_json = db_to_json(json) try: origin_server_ts = event_json["origin_server_ts"] except (KeyError, AttributeError): @@ -563,15 +565,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): if deleted: # We now need to invalidate the caches of these rooms - rows = self.db_pool.simple_select_many_txn( - txn, - table="events", - column="event_id", - iterable=to_delete, - keyvalues={}, - retcols=("room_id",), + rows = cast( + List[Tuple[str]], + self.db_pool.simple_select_many_txn( + txn, + table="events", + column="event_id", + iterable=to_delete, + keyvalues={}, + retcols=("room_id",), + ), ) - room_ids = {row["room_id"] for row in rows} + room_ids = {row[0] for row in rows} for room_id in room_ids: txn.call_after( self.get_latest_event_ids_in_room.invalidate, (room_id,) # type: ignore[attr-defined] @@ -1038,18 +1043,21 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): count = len(rows) # We also need to fetch the auth events for them. 
- auth_events = self.db_pool.simple_select_many_txn( - txn, - table="event_auth", - column="event_id", - iterable=event_to_room_id, - keyvalues={}, - retcols=("event_id", "auth_id"), + auth_events = cast( + List[Tuple[str, str]], + self.db_pool.simple_select_many_txn( + txn, + table="event_auth", + column="event_id", + iterable=event_to_room_id, + keyvalues={}, + retcols=("event_id", "auth_id"), + ), ) event_to_auth_chain: Dict[str, List[str]] = {} - for row in auth_events: - event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"]) + for event_id, auth_id in auth_events: + event_to_auth_chain.setdefault(event_id, []).append(auth_id) # Calculate and persist the chain cover index for this set of events. # diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index b788d70fc5..8af638d60f 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1584,16 +1584,19 @@ class EventsWorkerStore(SQLBaseStore): """Given a list of event ids, check if we have already processed and stored them as non outliers. """ - rows = await self.db_pool.simple_select_many_batch( - table="events", - retcols=("event_id",), - column="event_id", - iterable=list(event_ids), - keyvalues={"outlier": False}, - desc="have_events_in_timeline", + rows = cast( + List[Tuple[str]], + await self.db_pool.simple_select_many_batch( + table="events", + retcols=("event_id",), + column="event_id", + iterable=list(event_ids), + keyvalues={"outlier": False}, + desc="have_events_in_timeline", + ), ) - return {r["event_id"] for r in rows} + return {r[0] for r in rows} @trace @tag_args @@ -2336,15 +2339,18 @@ class EventsWorkerStore(SQLBaseStore): a dict mapping from event id to partial-stateness. We return True for any of the events which are unknown (or are outliers). 
""" - result = await self.db_pool.simple_select_many_batch( - table="partial_state_events", - column="event_id", - iterable=event_ids, - retcols=["event_id"], - desc="get_partial_state_events", + result = cast( + List[Tuple[str]], + await self.db_pool.simple_select_many_batch( + table="partial_state_events", + column="event_id", + iterable=event_ids, + retcols=["event_id"], + desc="get_partial_state_events", + ), ) # convert the result to a dict, to make @cachedList work - partial = {r["event_id"] for r in result} + partial = {r[0] for r in result} return {e_id: e_id in partial for e_id in event_ids} @cached() diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index 889c578b9c..ea797864b9 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -16,7 +16,7 @@ import itertools import json import logging -from typing import Dict, Iterable, Mapping, Optional, Tuple +from typing import Dict, Iterable, List, Mapping, Optional, Tuple, Union, cast from canonicaljson import encode_canonical_json from signedjson.key import decode_verify_key_bytes @@ -205,35 +205,39 @@ class KeyStore(CacheInvalidationWorkerStore): If we have multiple entries for a given key ID, returns the most recent. 
""" - rows = await self.db_pool.simple_select_many_batch( - table="server_keys_json", - column="key_id", - iterable=key_ids, - keyvalues={"server_name": server_name}, - retcols=( - "key_id", - "from_server", - "ts_added_ms", - "ts_valid_until_ms", - "key_json", + rows = cast( + List[Tuple[str, str, int, int, Union[bytes, memoryview]]], + await self.db_pool.simple_select_many_batch( + table="server_keys_json", + column="key_id", + iterable=key_ids, + keyvalues={"server_name": server_name}, + retcols=( + "key_id", + "from_server", + "ts_added_ms", + "ts_valid_until_ms", + "key_json", + ), + desc="get_server_keys_json_for_remote", ), - desc="get_server_keys_json_for_remote", ) if not rows: return {} - # We sort the rows so that the most recently added entry is picked up. - rows.sort(key=lambda r: r["ts_added_ms"]) + # We sort the rows by ts_added_ms so that the most recently added entry + # will stomp over older entries in the dictionary. + rows.sort(key=lambda r: r[2]) return { - row["key_id"]: FetchKeyResultForRemote( + key_id: FetchKeyResultForRemote( # Cast to bytes since postgresql returns a memoryview. - key_json=bytes(row["key_json"]), - valid_until_ts=row["ts_valid_until_ms"], - added_ts=row["ts_added_ms"], + key_json=bytes(key_json), + valid_until_ts=ts_valid_until_ms, + added_ts=ts_added_ms, ) - for row in rows + for key_id, from_server, ts_added_ms, ts_valid_until_ms, key_json in rows } async def get_all_server_keys_json_for_remote( @@ -260,6 +264,8 @@ class KeyStore(CacheInvalidationWorkerStore): if not rows: return {} + # We sort the rows by ts_added_ms so that the most recently added entry + # will stomp over older entries in the dictionary. 
rows.sort(key=lambda r: r["ts_added_ms"]) return { diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 519f05fb60..3b444d2d07 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -261,27 +261,40 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore) async def get_presence_for_users( self, user_ids: Iterable[str] ) -> Mapping[str, UserPresenceState]: - rows = await self.db_pool.simple_select_many_batch( - table="presence_stream", - column="user_id", - iterable=user_ids, - keyvalues={}, - retcols=( - "user_id", - "state", - "last_active_ts", - "last_federation_update_ts", - "last_user_sync_ts", - "status_msg", - "currently_active", + # TODO All these columns are nullable, but we don't expect that: + # https://github.com/matrix-org/synapse/issues/16467 + rows = cast( + List[Tuple[str, str, int, int, int, Optional[str], Union[int, bool]]], + await self.db_pool.simple_select_many_batch( + table="presence_stream", + column="user_id", + iterable=user_ids, + keyvalues={}, + retcols=( + "user_id", + "state", + "last_active_ts", + "last_federation_update_ts", + "last_user_sync_ts", + "status_msg", + "currently_active", + ), + desc="get_presence_for_users", ), - desc="get_presence_for_users", ) - for row in rows: - row["currently_active"] = bool(row["currently_active"]) - - return {row["user_id"]: UserPresenceState(**row) for row in rows} + return { + user_id: UserPresenceState( + user_id=user_id, + state=state, + last_active_ts=last_active_ts, + last_federation_update_ts=last_federation_update_ts, + last_user_sync_ts=last_user_sync_ts, + status_msg=status_msg, + currently_active=bool(currently_active), + ) + for user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, status_msg, currently_active in rows + } async def should_user_receive_full_presence_with_token( self, @@ -386,6 +399,8 @@ class 
PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore) limit = 100 offset = 0 while True: + # TODO All these columns are nullable, but we don't expect that: + # https://github.com/matrix-org/synapse/issues/16467 rows = cast( List[Tuple[str, str, int, int, int, Optional[str], Union[int, bool]]], await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 923166974c..f5356e7f80 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -62,20 +62,34 @@ logger = logging.getLogger(__name__) def _load_rules( - rawrules: List[JsonDict], + rawrules: List[Tuple[str, int, str, str]], enabled_map: Dict[str, bool], experimental_config: ExperimentalConfig, ) -> FilteredPushRules: """Take the DB rows returned from the DB and convert them into a full `FilteredPushRules` object. + + Args: + rawrules: List of tuples of: + * rule ID + * Priority class + * Conditions (as serialized JSON) + * Actions (as serialized JSON) + enabled_map: A dictionary of rule ID to a boolean of whether the rule is + enabled. This might not include all rule IDs from rawrules. + experimental_config: The `experimental_features` section of the Synapse + config. (Used to check if various features are enabled.) + + Returns: + A new FilteredPushRules object. 
""" ruleslist = [ PushRule.from_db( - rule_id=rawrule["rule_id"], - priority_class=rawrule["priority_class"], - conditions=rawrule["conditions"], - actions=rawrule["actions"], + rule_id=rawrule[0], + priority_class=rawrule[1], + conditions=rawrule[2], + actions=rawrule[3], ) for rawrule in rawrules ] @@ -183,7 +197,19 @@ class PushRulesWorkerStore( enabled_map = await self.get_push_rules_enabled_for_user(user_id) - return _load_rules(rows, enabled_map, self.hs.config.experimental) + return _load_rules( + [ + ( + row["rule_id"], + row["priority_class"], + row["conditions"], + row["actions"], + ) + for row in rows + ], + enabled_map, + self.hs.config.experimental, + ) async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]: results = await self.db_pool.simple_select_list( @@ -221,21 +247,36 @@ class PushRulesWorkerStore( if not user_ids: return {} - raw_rules: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids} + raw_rules: Dict[str, List[Tuple[str, int, str, str]]] = { + user_id: [] for user_id in user_ids + } - rows = await self.db_pool.simple_select_many_batch( - table="push_rules", - column="user_name", - iterable=user_ids, - retcols=("*",), - desc="bulk_get_push_rules", - batch_size=1000, + rows = cast( + List[Tuple[str, str, int, int, str, str]], + await self.db_pool.simple_select_many_batch( + table="push_rules", + column="user_name", + iterable=user_ids, + retcols=( + "user_name", + "rule_id", + "priority_class", + "priority", + "conditions", + "actions", + ), + desc="bulk_get_push_rules", + batch_size=1000, + ), ) - rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) + # Sort by highest priority_class, then highest priority. 
+ rows.sort(key=lambda row: (-int(row[2]), -int(row[3]))) - for row in rows: - raw_rules.setdefault(row["user_name"], []).append(row) + for user_name, rule_id, priority_class, _, conditions, actions in rows: + raw_rules.setdefault(user_name, []).append( + (rule_id, priority_class, conditions, actions) + ) enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids) @@ -256,17 +297,19 @@ class PushRulesWorkerStore( results: Dict[str, Dict[str, bool]] = {user_id: {} for user_id in user_ids} - rows = await self.db_pool.simple_select_many_batch( - table="push_rules_enable", - column="user_name", - iterable=user_ids, - retcols=("user_name", "rule_id", "enabled"), - desc="bulk_get_push_rules_enabled", - batch_size=1000, + rows = cast( + List[Tuple[str, str, Optional[int]]], + await self.db_pool.simple_select_many_batch( + table="push_rules_enable", + column="user_name", + iterable=user_ids, + retcols=("user_name", "rule_id", "enabled"), + desc="bulk_get_push_rules_enabled", + batch_size=1000, + ), ) - for row in rows: - enabled = bool(row["enabled"]) - results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled + for user_name, rule_id, enabled in rows: + results.setdefault(user_name, {})[rule_id] = bool(enabled) return results async def get_all_push_rule_updates( diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 9246b418f5..7f40e2c446 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -349,16 +349,19 @@ class RelationsWorkerStore(SQLBaseStore): def get_all_relation_ids_for_event_with_types_txn( txn: LoggingTransaction, ) -> List[str]: - rows = self.db_pool.simple_select_many_txn( - txn=txn, - table="event_relations", - column="relation_type", - iterable=relation_types, - keyvalues={"relates_to_id": event_id}, - retcols=["event_id"], + rows = cast( + List[Tuple[str]], + self.db_pool.simple_select_many_txn( + txn=txn, + 
table="event_relations", + column="relation_type", + iterable=relation_types, + keyvalues={"relates_to_id": event_id}, + retcols=["event_id"], + ), ) - return [row["event_id"] for row in rows] + return [row[0] for row in rows] return await self.db_pool.runInteraction( desc="get_all_relation_ids_for_event_with_types", diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 1d4d99932b..9d24d2c347 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1296,14 +1296,17 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): complete. """ - rows: List[Dict[str, str]] = await self.db_pool.simple_select_many_batch( - table="partial_state_rooms", - column="room_id", - iterable=room_ids, - retcols=("room_id",), - desc="is_partial_state_room_batched", - ) - partial_state_rooms = {row_dict["room_id"] for row_dict in rows} + rows = cast( + List[Tuple[str]], + await self.db_pool.simple_select_many_batch( + table="partial_state_rooms", + column="room_id", + iterable=room_ids, + retcols=("room_id",), + desc="is_partial_state_room_batched", + ), + ) + partial_state_rooms = {row[0] for row in rows} return {room_id: room_id in partial_state_rooms for room_id in room_ids} async def get_join_event_id_and_device_lists_stream_id_for_partial_state( diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index bbe08368db..3a87eba430 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -27,6 +27,7 @@ from typing import ( Set, Tuple, Union, + cast, ) import attr @@ -683,25 +684,28 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): Map from user_id to set of rooms that is currently in. 
""" - rows = await self.db_pool.simple_select_many_batch( - table="current_state_events", - column="state_key", - iterable=user_ids, - retcols=( - "state_key", - "room_id", + rows = cast( + List[Tuple[str, str]], + await self.db_pool.simple_select_many_batch( + table="current_state_events", + column="state_key", + iterable=user_ids, + retcols=( + "state_key", + "room_id", + ), + keyvalues={ + "type": EventTypes.Member, + "membership": Membership.JOIN, + }, + desc="get_rooms_for_users", ), - keyvalues={ - "type": EventTypes.Member, - "membership": Membership.JOIN, - }, - desc="get_rooms_for_users", ) user_rooms: Dict[str, Set[str]] = {user_id: set() for user_id in user_ids} - for row in rows: - user_rooms[row["state_key"]].add(row["room_id"]) + for state_key, room_id in rows: + user_rooms[state_key].add(room_id) return {key: frozenset(rooms) for key, rooms in user_rooms.items()} @@ -892,17 +896,20 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): Map from event ID to `user_id`, or None if event is not a join. """ - rows = await self.db_pool.simple_select_many_batch( - table="room_memberships", - column="event_id", - iterable=event_ids, - retcols=("user_id", "event_id"), - keyvalues={"membership": Membership.JOIN}, - batch_size=1000, - desc="_get_user_ids_from_membership_event_ids", + rows = cast( + List[Tuple[str, str]], + await self.db_pool.simple_select_many_batch( + table="room_memberships", + column="event_id", + iterable=event_ids, + retcols=("event_id", "user_id"), + keyvalues={"membership": Membership.JOIN}, + batch_size=1000, + desc="_get_user_ids_from_membership_event_ids", + ), ) - return {row["event_id"]: row["user_id"] for row in rows} + return dict(rows) @cached(max_entries=10000) async def is_host_joined(self, room_id: str, host: str) -> bool: @@ -1202,21 +1209,22 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): membership event, otherwise the value is None. 
""" - rows = await self.db_pool.simple_select_many_batch( - table="room_memberships", - column="event_id", - iterable=member_event_ids, - retcols=("user_id", "membership", "event_id"), - keyvalues={}, - batch_size=500, - desc="get_membership_from_event_ids", + rows = cast( + List[Tuple[str, str, str]], + await self.db_pool.simple_select_many_batch( + table="room_memberships", + column="event_id", + iterable=member_event_ids, + retcols=("user_id", "membership", "event_id"), + keyvalues={}, + batch_size=500, + desc="get_membership_from_event_ids", + ), ) return { - row["event_id"]: EventIdMembership( - membership=row["membership"], user_id=row["user_id"] - ) - for row in rows + event_id: EventIdMembership(membership=membership, user_id=user_id) + for user_id, membership, event_id in rows } async def is_local_host_in_room_ignoring_users( diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 5eaaff5b68..598025dd91 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -20,10 +20,12 @@ from typing import ( Collection, Dict, Iterable, + List, Mapping, Optional, Set, Tuple, + cast, ) import attr @@ -388,16 +390,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): Raises: RuntimeError if the state is unknown at any of the given events """ - rows = await self.db_pool.simple_select_many_batch( - table="event_to_state_groups", - column="event_id", - iterable=event_ids, - keyvalues={}, - retcols=("event_id", "state_group"), - desc="_get_state_group_for_events", + rows = cast( + List[Tuple[str, int]], + await self.db_pool.simple_select_many_batch( + table="event_to_state_groups", + column="event_id", + iterable=event_ids, + keyvalues={}, + retcols=("event_id", "state_group"), + desc="_get_state_group_for_events", + ), ) - res = {row["event_id"]: row["state_group"] for row in rows} + res = dict(rows) for e in event_ids: if e not in res: raise RuntimeError("No state group 
for unknown or outlier event %s" % e) @@ -415,16 +420,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): The subset of state groups that are referenced. """ - rows = await self.db_pool.simple_select_many_batch( - table="event_to_state_groups", - column="state_group", - iterable=state_groups, - keyvalues={}, - retcols=("DISTINCT state_group",), - desc="get_referenced_state_groups", + rows = cast( + List[Tuple[int]], + await self.db_pool.simple_select_many_batch( + table="event_to_state_groups", + column="state_group", + iterable=state_groups, + keyvalues={}, + retcols=("DISTINCT state_group",), + desc="get_referenced_state_groups", + ), ) - return {row["state_group"] for row in rows} + return {row[0] for row in rows} async def update_state_for_partial_state_event( self, @@ -624,16 +632,22 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): # potentially stale, since there may have been a period where the # server didn't share a room with the remote user and therefore may # have missed any device updates. - rows = self.db_pool.simple_select_many_txn( - txn, - table="current_state_events", - column="room_id", - iterable=to_delete, - keyvalues={"type": EventTypes.Member, "membership": Membership.JOIN}, - retcols=("state_key",), + rows = cast( + List[Tuple[str]], + self.db_pool.simple_select_many_txn( + txn, + table="current_state_events", + column="room_id", + iterable=to_delete, + keyvalues={ + "type": EventTypes.Member, + "membership": Membership.JOIN, + }, + retcols=("state_key",), + ), ) - potentially_left_users = {row["state_key"] for row in rows} + potentially_left_users = {row[0] for row in rows} # Now lets actually delete the rooms from the DB. 
self.db_pool.simple_delete_many_txn( diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 9d403919e4..5b2d0ba870 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -506,25 +506,28 @@ class StatsStore(StateDeltasStore): ) -> Tuple[List[str], Dict[str, int], int, List[str], int]: pos = self.get_room_max_stream_ordering() # type: ignore[attr-defined] - rows = self.db_pool.simple_select_many_txn( - txn, - table="current_state_events", - column="type", - iterable=[ - EventTypes.Create, - EventTypes.JoinRules, - EventTypes.RoomHistoryVisibility, - EventTypes.RoomEncryption, - EventTypes.Name, - EventTypes.Topic, - EventTypes.RoomAvatar, - EventTypes.CanonicalAlias, - ], - keyvalues={"room_id": room_id, "state_key": ""}, - retcols=["event_id"], + rows = cast( + List[Tuple[str]], + self.db_pool.simple_select_many_txn( + txn, + table="current_state_events", + column="type", + iterable=[ + EventTypes.Create, + EventTypes.JoinRules, + EventTypes.RoomHistoryVisibility, + EventTypes.RoomEncryption, + EventTypes.Name, + EventTypes.Topic, + EventTypes.RoomAvatar, + EventTypes.CanonicalAlias, + ], + keyvalues={"room_id": room_id, "state_key": ""}, + retcols=["event_id"], + ), ) - event_ids = cast(List[str], [row["event_id"] for row in rows]) + event_ids = [row[0] for row in rows] txn.execute( """ diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index f35757280d..c4a6475060 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -211,18 +211,28 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): async def get_destination_retry_timings_batch( self, destinations: StrCollection ) -> Mapping[str, Optional[DestinationRetryTimings]]: - rows = await self.db_pool.simple_select_many_batch( - table="destinations", - iterable=destinations, - column="destination", - 
retcols=("destination", "failure_ts", "retry_last_ts", "retry_interval"), - desc="get_destination_retry_timings_batch", + rows = cast( + List[Tuple[str, Optional[int], Optional[int], Optional[int]]], + await self.db_pool.simple_select_many_batch( + table="destinations", + iterable=destinations, + column="destination", + retcols=( + "destination", + "failure_ts", + "retry_last_ts", + "retry_interval", + ), + desc="get_destination_retry_timings_batch", + ), ) return { - row.pop("destination"): DestinationRetryTimings(**row) - for row in rows - if row["retry_last_ts"] and row["failure_ts"] and row["retry_interval"] + destination: DestinationRetryTimings( + failure_ts, retry_last_ts, retry_interval + ) + for destination, failure_ts, retry_last_ts, retry_interval in rows + if retry_last_ts and failure_ts and retry_interval } async def set_destination_retry_timings( diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index f38bedbbcd..919c66f553 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -337,13 +337,16 @@ class UIAuthWorkerStore(SQLBaseStore): # If a registration token was used, decrement the pending counter # before deleting the session. - rows = self.db_pool.simple_select_many_txn( - txn, - table="ui_auth_sessions_credentials", - column="session_id", - iterable=session_ids, - keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN}, - retcols=["result"], + rows = cast( + List[Tuple[str]], + self.db_pool.simple_select_many_txn( + txn, + table="ui_auth_sessions_credentials", + column="session_id", + iterable=session_ids, + keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN}, + retcols=["result"], + ), ) # Get the tokens used and how much pending needs to be decremented by. @@ -353,23 +356,25 @@ class UIAuthWorkerStore(SQLBaseStore): # registration token stage for that session will be True. 
# If a token was used to authenticate, but registration was # never completed, the result will be the token used. - token = db_to_json(r["result"]) + token = db_to_json(r[0]) if isinstance(token, str): token_counts[token] = token_counts.get(token, 0) + 1 # Update the `pending` counters. if len(token_counts) > 0: - token_rows = self.db_pool.simple_select_many_txn( - txn, - table="registration_tokens", - column="token", - iterable=list(token_counts.keys()), - keyvalues={}, - retcols=["token", "pending"], + token_rows = cast( + List[Tuple[str, int]], + self.db_pool.simple_select_many_txn( + txn, + table="registration_tokens", + column="token", + iterable=list(token_counts.keys()), + keyvalues={}, + retcols=["token", "pending"], + ), ) - for token_row in token_rows: - token = token_row["token"] - new_pending = token_row["pending"] - token_counts[token] + for token, pending in token_rows: + new_pending = pending - token_counts[token] self.db_pool.simple_update_one_txn( txn, table="registration_tokens", diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index f0dc31fee6..23eb92c514 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -410,25 +410,24 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): ) # Next fetch their profiles. Note that not all users have profiles. 
- profile_rows = self.db_pool.simple_select_many_txn( - txn, - table="profiles", - column="full_user_id", - iterable=list(users_to_insert), - retcols=( - "full_user_id", - "displayname", - "avatar_url", + profile_rows = cast( + List[Tuple[str, Optional[str], Optional[str]]], + self.db_pool.simple_select_many_txn( + txn, + table="profiles", + column="full_user_id", + iterable=list(users_to_insert), + retcols=( + "full_user_id", + "displayname", + "avatar_url", + ), + keyvalues={}, ), - keyvalues={}, ) profiles = { - row["full_user_id"]: _UserDirProfile( - row["full_user_id"], - row["displayname"], - row["avatar_url"], - ) - for row in profile_rows + full_user_id: _UserDirProfile(full_user_id, displayname, avatar_url) + for full_user_id, displayname, avatar_url in profile_rows } profiles_to_insert = [ @@ -517,18 +516,21 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): and not self.get_if_app_services_interested_in_user(user) # type: ignore[attr-defined] ] - rows = self.db_pool.simple_select_many_txn( - txn, - table="users", - column="name", - iterable=users, - keyvalues={ - "deactivated": 0, - }, - retcols=("name", "user_type"), + rows = cast( + List[Tuple[str, Optional[str]]], + self.db_pool.simple_select_many_txn( + txn, + table="users", + column="name", + iterable=users, + keyvalues={ + "deactivated": 0, + }, + retcols=("name", "user_type"), + ), ) - return [row["name"] for row in rows if row["user_type"] != UserTypes.SUPPORT] + return [name for name, user_type in rows if user_type != UserTypes.SUPPORT] async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool: """Check if the room is either world_readable or publically joinable""" diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py index 06fcbe5e54..8bd58c6e3d 100644 --- a/synapse/storage/databases/main/user_erasure_store.py +++ b/synapse/storage/databases/main/user_erasure_store.py @@ -12,7 +12,7 @@ # See 
the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable, Mapping +from typing import Iterable, List, Mapping, Tuple, cast from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main import CacheInvalidationWorkerStore @@ -50,14 +50,17 @@ class UserErasureWorkerStore(CacheInvalidationWorkerStore): Returns: for each user, whether the user has requested erasure. """ - rows = await self.db_pool.simple_select_many_batch( - table="erased_users", - column="user_id", - iterable=user_ids, - retcols=("user_id",), - desc="are_users_erased", + rows = cast( + List[Tuple[str]], + await self.db_pool.simple_select_many_batch( + table="erased_users", + column="user_id", + iterable=user_ids, + retcols=("user_id",), + desc="are_users_erased", + ), ) - erased_users = {row["user_id"] for row in rows} + erased_users = {row[0] for row in rows} return {u: u in erased_users for u in user_ids} diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 6984d11352..09d2a8c5b3 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -13,7 +13,17 @@ # limitations under the License. 
import logging -from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + cast, +) import attr @@ -730,19 +740,22 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): "[purge] found %i state groups to delete", len(state_groups_to_delete) ) - rows = self.db_pool.simple_select_many_txn( - txn, - table="state_group_edges", - column="prev_state_group", - iterable=state_groups_to_delete, - keyvalues={}, - retcols=("state_group",), + rows = cast( + List[Tuple[int]], + self.db_pool.simple_select_many_txn( + txn, + table="state_group_edges", + column="prev_state_group", + iterable=state_groups_to_delete, + keyvalues={}, + retcols=("state_group",), + ), ) remaining_state_groups = { - row["state_group"] - for row in rows - if row["state_group"] not in state_groups_to_delete + state_group + for state_group, in rows + if state_group not in state_groups_to_delete } logger.info( @@ -799,16 +812,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): A mapping from state group to previous state group. 
""" - rows = await self.db_pool.simple_select_many_batch( - table="state_group_edges", - column="prev_state_group", - iterable=state_groups, - keyvalues={}, - retcols=("prev_state_group", "state_group"), - desc="get_previous_state_groups", + rows = cast( + List[Tuple[int, int]], + await self.db_pool.simple_select_many_batch( + table="state_group_edges", + column="prev_state_group", + iterable=state_groups, + keyvalues={}, + retcols=("state_group", "prev_state_group"), + desc="get_previous_state_groups", + ), ) - return {row["state_group"]: row["prev_state_group"] for row in rows} + return dict(rows) async def purge_room_state( self, room_id: str, state_groups_to_delete: Collection[int] diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index b55dd07f14..2f6499966c 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Dict, List, Set, Tuple +from typing import Dict, List, Set, Tuple, cast from twisted.test.proto_helpers import MemoryReactor from twisted.trial import unittest @@ -421,41 +421,53 @@ class EventChainStoreTestCase(HomeserverTestCase): self, events: List[EventBase] ) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]: # Fetch the map from event ID -> (chain ID, sequence number) - rows = self.get_success( - self.store.db_pool.simple_select_many_batch( - table="event_auth_chains", - column="event_id", - iterable=[e.event_id for e in events], - retcols=("event_id", "chain_id", "sequence_number"), - keyvalues={}, - ) + rows = cast( + List[Tuple[str, int, int]], + self.get_success( + self.store.db_pool.simple_select_many_batch( + table="event_auth_chains", + column="event_id", + iterable=[e.event_id for e in events], + retcols=("event_id", "chain_id", "sequence_number"), + keyvalues={}, + ) + ), ) chain_map = { - row["event_id"]: (row["chain_id"], row["sequence_number"]) for row in rows + event_id: (chain_id, sequence_number) + for event_id, chain_id, sequence_number in rows } # Fetch all the links and pass them to the _LinkMap. 
- rows = self.get_success( - self.store.db_pool.simple_select_many_batch( - table="event_auth_chain_links", - column="origin_chain_id", - iterable=[chain_id for chain_id, _ in chain_map.values()], - retcols=( - "origin_chain_id", - "origin_sequence_number", - "target_chain_id", - "target_sequence_number", - ), - keyvalues={}, - ) + auth_chain_rows = cast( + List[Tuple[int, int, int, int]], + self.get_success( + self.store.db_pool.simple_select_many_batch( + table="event_auth_chain_links", + column="origin_chain_id", + iterable=[chain_id for chain_id, _ in chain_map.values()], + retcols=( + "origin_chain_id", + "origin_sequence_number", + "target_chain_id", + "target_sequence_number", + ), + keyvalues={}, + ) + ), ) link_map = _LinkMap() - for row in rows: + for ( + origin_chain_id, + origin_sequence_number, + target_chain_id, + target_sequence_number, + ) in auth_chain_rows: added = link_map.add_link( - (row["origin_chain_id"], row["origin_sequence_number"]), - (row["target_chain_id"], row["target_sequence_number"]), + (origin_chain_id, origin_sequence_number), + (target_chain_id, target_sequence_number), ) # We shouldn't have persisted any redundant links -- cgit 1.5.1 From 9407d5ba78d1e5275b5817ae9e6aedf7d1ca14f7 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 26 Oct 2023 13:01:36 -0400 Subject: Convert simple_select_list and simple_select_list_txn to return lists of tuples (#16505) This should use fewer allocations and improves type hints. 
--- changelog.d/16505.misc | 1 + synapse/handlers/deactivate_account.py | 4 +- synapse/handlers/sso.py | 5 +- synapse/storage/database.py | 31 +-- synapse/storage/databases/main/account_data.py | 18 +- synapse/storage/databases/main/appservice.py | 13 +- synapse/storage/databases/main/client_ips.py | 25 ++- synapse/storage/databases/main/devices.py | 70 +++--- synapse/storage/databases/main/e2e_room_keys.py | 49 ++-- synapse/storage/databases/main/event_federation.py | 18 +- .../databases/main/experimental_features.py | 15 +- synapse/storage/databases/main/keys.py | 35 +-- synapse/storage/databases/main/media_repository.py | 58 +++-- synapse/storage/databases/main/push_rule.py | 52 +++-- synapse/storage/databases/main/pusher.py | 20 +- synapse/storage/databases/main/registration.py | 60 +++-- synapse/storage/databases/main/relations.py | 15 +- synapse/storage/databases/main/room.py | 34 +-- synapse/storage/databases/main/roommember.py | 15 +- synapse/storage/databases/main/tags.py | 28 ++- synapse/storage/databases/main/ui_auth.py | 32 +-- synapse/storage/databases/state/store.py | 18 +- tests/handlers/test_stats.py | 14 +- tests/storage/databases/main/test_receipts.py | 20 +- tests/storage/test__base.py | 16 +- tests/storage/test_background_update.py | 35 +-- tests/storage/test_base.py | 4 +- tests/storage/test_client_ips.py | 250 ++++++++++----------- tests/storage/test_roommember.py | 40 ++-- tests/storage/test_state.py | 62 ++--- tests/storage/test_user_directory.py | 61 ++--- 31 files changed, 609 insertions(+), 509 deletions(-) create mode 100644 changelog.d/16505.misc (limited to 'synapse/storage/databases/main/devices.py') diff --git a/changelog.d/16505.misc b/changelog.d/16505.misc new file mode 100644 index 0000000000..bd7cdd42af --- /dev/null +++ b/changelog.d/16505.misc @@ -0,0 +1 @@ +Reduce memory allocations. 
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 6a8f8f2fd1..370f4041fb 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -103,10 +103,10 @@ class DeactivateAccountHandler: # Attempt to unbind any known bound threepids to this account from identity # server(s). bound_threepids = await self.store.user_get_bound_threepids(user_id) - for threepid in bound_threepids: + for medium, address in bound_threepids: try: result = await self._identity_handler.try_unbind_threepid( - user_id, threepid["medium"], threepid["address"], id_server + user_id, medium, address, id_server ) except Exception: # Do we want this to be a fatal error or should we carry on? diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index e9a544e754..62f2454f5d 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -1206,10 +1206,7 @@ class SsoHandler: # We have no guarantee that all the devices of that session are for the same # `user_id`. Hence, we have to iterate over the list of devices and log them out # one by one. - for device in devices: - user_id = device["user_id"] - device_id = device["device_id"] - + for user_id, device_id in devices: # If the user_id associated with that device/session is not the one we got # out of the `sub` claim, skip that device and show log an error. if expected_user_id is not None and user_id != expected_user_id: diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 81f661160c..774d5c12f0 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -606,13 +606,16 @@ class DatabasePool: If the background updates have not completed, wait 15 sec and check again. 
""" - updates = await self.simple_select_list( - "background_updates", - keyvalues=None, - retcols=["update_name"], - desc="check_background_updates", + updates = cast( + List[Tuple[str]], + await self.simple_select_list( + "background_updates", + keyvalues=None, + retcols=["update_name"], + desc="check_background_updates", + ), ) - background_update_names = [x["update_name"] for x in updates] + background_update_names = [x[0] for x in updates] for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items(): if update_name not in background_update_names: @@ -1804,9 +1807,9 @@ class DatabasePool: keyvalues: Optional[Dict[str, Any]], retcols: Collection[str], desc: str = "simple_select_list", - ) -> List[Dict[str, Any]]: + ) -> List[Tuple[Any, ...]]: """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. + more rows, returning the result as a list of tuples. Args: table: the table name @@ -1817,8 +1820,7 @@ class DatabasePool: desc: description of the transaction, for logging and metrics Returns: - A list of dictionaries, one per result row, each a mapping between the - column names from `retcols` and that column's value for the row. + A list of tuples, one per result row, each the retcolumn's value for the row. """ return await self.runInteraction( desc, @@ -1836,9 +1838,9 @@ class DatabasePool: table: str, keyvalues: Optional[Dict[str, Any]], retcols: Iterable[str], - ) -> List[Dict[str, Any]]: + ) -> List[Tuple[Any, ...]]: """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. + more rows, returning the result as a list of tuples. Args: txn: Transaction object @@ -1849,8 +1851,7 @@ class DatabasePool: retcols: the names of the columns to return Returns: - A list of dictionaries, one per result row, each a mapping between the - column names from `retcols` and that column's value for the row. 
+ A list of tuples, one per result row, each the retcolumn's value for the row. """ if keyvalues: sql = "SELECT %s FROM %s WHERE %s" % ( @@ -1863,7 +1864,7 @@ class DatabasePool: sql = "SELECT %s FROM %s" % (", ".join(retcols), table) txn.execute(sql) - return cls.cursor_to_dict(txn) + return txn.fetchall() async def simple_select_many_batch( self, diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 84ef8136c2..d7482a1f4e 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -286,16 +286,20 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) def get_account_data_for_room_txn( txn: LoggingTransaction, - ) -> Dict[str, JsonDict]: - rows = self.db_pool.simple_select_list_txn( - txn, - "room_account_data", - {"user_id": user_id, "room_id": room_id}, - ["account_data_type", "content"], + ) -> Dict[str, JsonMapping]: + rows = cast( + List[Tuple[str, str]], + self.db_pool.simple_select_list_txn( + txn, + table="room_account_data", + keyvalues={"user_id": user_id, "room_id": room_id}, + retcols=["account_data_type", "content"], + ), ) return { - row["account_data_type"]: db_to_json(row["content"]) for row in rows + account_data_type: db_to_json(content) + for account_data_type, content in rows } return await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 073a99cd84..fa7d1c469a 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -197,16 +197,21 @@ class ApplicationServiceTransactionWorkerStore( Returns: A list of ApplicationServices, which may be empty. 
""" - results = await self.db_pool.simple_select_list( - "application_services_state", {"state": state.value}, ["as_id"] + results = cast( + List[Tuple[str]], + await self.db_pool.simple_select_list( + table="application_services_state", + keyvalues={"state": state.value}, + retcols=("as_id",), + ), ) # NB: This assumes this class is linked with ApplicationServiceStore as_list = self.get_app_services() services = [] - for res in results: + for (as_id,) in results: for service in as_list: - if service.id == res["as_id"]: + if service.id == as_id: services.append(service) return services diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index 8be1511859..c006129625 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -508,21 +508,24 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke if device_id is not None: keyvalues["device_id"] = device_id - res = await self.db_pool.simple_select_list( - table="devices", - keyvalues=keyvalues, - retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), + res = cast( + List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]], + await self.db_pool.simple_select_list( + table="devices", + keyvalues=keyvalues, + retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), + ), ) return { - (d["user_id"], d["device_id"]): DeviceLastConnectionInfo( - user_id=d["user_id"], - device_id=d["device_id"], - ip=d["ip"], - user_agent=d["user_agent"], - last_seen=d["last_seen"], + (user_id, device_id): DeviceLastConnectionInfo( + user_id=user_id, + device_id=device_id, + ip=ip, + user_agent=user_agent, + last_seen=last_seen, ) - for d in res + for user_id, ip, user_agent, device_id, last_seen in res } async def _get_user_ip_and_agents_from_database( diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index fc23d18eba..0b75f6763a 
100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -283,7 +283,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): allow_none=True, ) - async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]: + async def get_devices_by_user( + self, user_id: str + ) -> Dict[str, Dict[str, Optional[str]]]: """Retrieve all of a user's registered devices. Only returns devices that are not marked as hidden. @@ -291,20 +293,26 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): user_id: Returns: A mapping from device_id to a dict containing "device_id", "user_id" - and "display_name" for each device. + and "display_name" for each device. Display name may be null. """ - devices = await self.db_pool.simple_select_list( - table="devices", - keyvalues={"user_id": user_id, "hidden": False}, - retcols=("user_id", "device_id", "display_name"), - desc="get_devices_by_user", + devices = cast( + List[Tuple[str, str, Optional[str]]], + await self.db_pool.simple_select_list( + table="devices", + keyvalues={"user_id": user_id, "hidden": False}, + retcols=("user_id", "device_id", "display_name"), + desc="get_devices_by_user", + ), ) - return {d["device_id"]: d for d in devices} + return { + d[1]: {"user_id": d[0], "device_id": d[1], "display_name": d[2]} + for d in devices + } async def get_devices_by_auth_provider_session_id( self, auth_provider_id: str, auth_provider_session_id: str - ) -> List[Dict[str, Any]]: + ) -> List[Tuple[str, str]]: """Retrieve the list of devices associated with a SSO IdP session ID. 
Args: @@ -313,14 +321,17 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): Returns: A list of dicts containing the device_id and the user_id of each device """ - return await self.db_pool.simple_select_list( - table="device_auth_providers", - keyvalues={ - "auth_provider_id": auth_provider_id, - "auth_provider_session_id": auth_provider_session_id, - }, - retcols=("user_id", "device_id"), - desc="get_devices_by_auth_provider_session_id", + return cast( + List[Tuple[str, str]], + await self.db_pool.simple_select_list( + table="device_auth_providers", + keyvalues={ + "auth_provider_id": auth_provider_id, + "auth_provider_session_id": auth_provider_session_id, + }, + retcols=("user_id", "device_id"), + desc="get_devices_by_auth_provider_session_id", + ), ) @trace @@ -821,15 +832,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): async def get_cached_devices_for_user( self, user_id: str ) -> Mapping[str, JsonMapping]: - devices = await self.db_pool.simple_select_list( - table="device_lists_remote_cache", - keyvalues={"user_id": user_id}, - retcols=("device_id", "content"), - desc="get_cached_devices_for_user", + devices = cast( + List[Tuple[str, str]], + await self.db_pool.simple_select_list( + table="device_lists_remote_cache", + keyvalues={"user_id": user_id}, + retcols=("device_id", "content"), + desc="get_cached_devices_for_user", + ), ) - return { - device["device_id"]: db_to_json(device["content"]) for device in devices - } + return {device[0]: db_to_json(device[1]) for device in devices} def get_cached_device_list_changes( self, @@ -1080,7 +1092,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): The IDs of users whose device lists need resync. 
""" if user_ids: - row_tuples = cast( + rows = cast( List[Tuple[str]], await self.db_pool.simple_select_many_batch( table="device_lists_remote_resync", @@ -1090,11 +1102,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): desc="get_user_ids_requiring_device_list_resync_with_iterable", ), ) - - return {row[0] for row in row_tuples} else: rows = cast( - List[Dict[str, str]], + List[Tuple[str]], await self.db_pool.simple_select_list( table="device_lists_remote_resync", keyvalues=None, @@ -1103,7 +1113,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ), ) - return {row["user_id"] for row in rows} + return {row[0] for row in rows} async def mark_remote_users_device_caches_as_stale( self, user_ids: StrCollection diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index aac4cfb054..ad904a26a6 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import TYPE_CHECKING, Dict, Iterable, Mapping, Optional, Tuple, cast +from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple, cast from typing_extensions import Literal, TypedDict @@ -274,32 +274,41 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore): if session_id: keyvalues["session_id"] = session_id - rows = await self.db_pool.simple_select_list( - table="e2e_room_keys", - keyvalues=keyvalues, - retcols=( - "user_id", - "room_id", - "session_id", - "first_message_index", - "forwarded_count", - "is_verified", - "session_data", + rows = cast( + List[Tuple[str, str, int, int, int, str]], + await self.db_pool.simple_select_list( + table="e2e_room_keys", + keyvalues=keyvalues, + retcols=( + "room_id", + "session_id", + "first_message_index", + "forwarded_count", + "is_verified", + "session_data", + ), + desc="get_e2e_room_keys", ), - desc="get_e2e_room_keys", ) sessions: Dict[ Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]] ] = {"rooms": {}} - for row in rows: - room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}}) - room_entry["sessions"][row["session_id"]] = { - "first_message_index": row["first_message_index"], - "forwarded_count": row["forwarded_count"], + for ( + room_id, + session_id, + first_message_index, + forwarded_count, + is_verified, + session_data, + ) in rows: + room_entry = sessions["rooms"].setdefault(room_id, {"sessions": {}}) + room_entry["sessions"][session_id] = { + "first_message_index": first_message_index, + "forwarded_count": forwarded_count, # is_verified must be returned to the client as a boolean - "is_verified": bool(row["is_verified"]), - "session_data": db_to_json(row["session_data"]), + "is_verified": bool(is_verified), + "session_data": db_to_json(session_data), } return sessions diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 4f80ce75cc..f1b0991503 100644 --- 
a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1898,21 +1898,23 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # keeping only the forward extremities (i.e. the events not referenced # by other events in the queue). We do this so that we can always # backpaginate in all the events we have dropped. - rows = await self.db_pool.simple_select_list( - table="federation_inbound_events_staging", - keyvalues={"room_id": room_id}, - retcols=("event_id", "event_json"), - desc="prune_staged_events_in_room_fetch", + rows = cast( + List[Tuple[str, str]], + await self.db_pool.simple_select_list( + table="federation_inbound_events_staging", + keyvalues={"room_id": room_id}, + retcols=("event_id", "event_json"), + desc="prune_staged_events_in_room_fetch", + ), ) # Find the set of events referenced by those in the queue, as well as # collecting all the event IDs in the queue. referenced_events: Set[str] = set() seen_events: Set[str] = set() - for row in rows: - event_id = row["event_id"] + for event_id, event_json in rows: seen_events.add(event_id) - event_d = db_to_json(row["event_json"]) + event_d = db_to_json(event_json) # We don't bother parsing the dicts into full blown event objects, # as that is needlessly expensive. diff --git a/synapse/storage/databases/main/experimental_features.py b/synapse/storage/databases/main/experimental_features.py index 654f924019..60621edeef 100644 --- a/synapse/storage/databases/main/experimental_features.py +++ b/synapse/storage/databases/main/experimental_features.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import TYPE_CHECKING, Dict, FrozenSet +from typing import TYPE_CHECKING, Dict, FrozenSet, List, Tuple, cast from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main import CacheInvalidationWorkerStore @@ -42,13 +42,16 @@ class ExperimentalFeaturesStore(CacheInvalidationWorkerStore): Returns: the features currently enabled for the user """ - enabled = await self.db_pool.simple_select_list( - "per_user_experimental_features", - {"user_id": user_id, "enabled": True}, - ["feature"], + enabled = cast( + List[Tuple[str]], + await self.db_pool.simple_select_list( + table="per_user_experimental_features", + keyvalues={"user_id": user_id, "enabled": True}, + retcols=("feature",), + ), ) - return frozenset(feature["feature"] for feature in enabled) + return frozenset(feature[0] for feature in enabled) async def set_features_for_user( self, diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index ea797864b9..ce88772f9e 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -248,17 +248,20 @@ class KeyStore(CacheInvalidationWorkerStore): If we have multiple entries for a given key ID, returns the most recent. 
""" - rows = await self.db_pool.simple_select_list( - table="server_keys_json", - keyvalues={"server_name": server_name}, - retcols=( - "key_id", - "from_server", - "ts_added_ms", - "ts_valid_until_ms", - "key_json", + rows = cast( + List[Tuple[str, str, int, int, Union[bytes, memoryview]]], + await self.db_pool.simple_select_list( + table="server_keys_json", + keyvalues={"server_name": server_name}, + retcols=( + "key_id", + "from_server", + "ts_added_ms", + "ts_valid_until_ms", + "key_json", + ), + desc="get_server_keys_json_for_remote", ), - desc="get_server_keys_json_for_remote", ) if not rows: @@ -266,14 +269,14 @@ class KeyStore(CacheInvalidationWorkerStore): # We sort the rows by ts_added_ms so that the most recently added entry # will stomp over older entries in the dictionary. - rows.sort(key=lambda r: r["ts_added_ms"]) + rows.sort(key=lambda r: r[2]) return { - row["key_id"]: FetchKeyResultForRemote( + key_id: FetchKeyResultForRemote( # Cast to bytes since postgresql returns a memoryview. 
- key_json=bytes(row["key_json"]), - valid_until_ts=row["ts_valid_until_ms"], - added_ts=row["ts_added_ms"], + key_json=bytes(key_json), + valid_until_ts=ts_valid_until_ms, + added_ts=ts_added_ms, ) - for row in rows + for key_id, from_server, ts_added_ms, ts_valid_until_ms, key_json in rows } diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 2e6b176bd2..f82140b2e8 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -437,25 +437,24 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) async def get_local_media_thumbnails(self, media_id: str) -> List[ThumbnailInfo]: - rows = await self.db_pool.simple_select_list( - "local_media_repository_thumbnails", - {"media_id": media_id}, - ( - "thumbnail_width", - "thumbnail_height", - "thumbnail_method", - "thumbnail_type", - "thumbnail_length", + rows = cast( + List[Tuple[int, int, str, str, int]], + await self.db_pool.simple_select_list( + "local_media_repository_thumbnails", + {"media_id": media_id}, + ( + "thumbnail_width", + "thumbnail_height", + "thumbnail_method", + "thumbnail_type", + "thumbnail_length", + ), + desc="get_local_media_thumbnails", ), - desc="get_local_media_thumbnails", ) return [ ThumbnailInfo( - width=row["thumbnail_width"], - height=row["thumbnail_height"], - method=row["thumbnail_method"], - type=row["thumbnail_type"], - length=row["thumbnail_length"], + width=row[0], height=row[1], method=row[2], type=row[3], length=row[4] ) for row in rows ] @@ -568,25 +567,24 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): async def get_remote_media_thumbnails( self, origin: str, media_id: str ) -> List[ThumbnailInfo]: - rows = await self.db_pool.simple_select_list( - "remote_media_cache_thumbnails", - {"media_origin": origin, "media_id": media_id}, - ( - "thumbnail_width", - "thumbnail_height", - "thumbnail_method", - 
"thumbnail_type", - "thumbnail_length", + rows = cast( + List[Tuple[int, int, str, str, int]], + await self.db_pool.simple_select_list( + "remote_media_cache_thumbnails", + {"media_origin": origin, "media_id": media_id}, + ( + "thumbnail_width", + "thumbnail_height", + "thumbnail_method", + "thumbnail_type", + "thumbnail_length", + ), + desc="get_remote_media_thumbnails", ), - desc="get_remote_media_thumbnails", ) return [ ThumbnailInfo( - width=row["thumbnail_width"], - height=row["thumbnail_height"], - method=row["thumbnail_method"], - type=row["thumbnail_type"], - length=row["thumbnail_length"], + width=row[0], height=row[1], method=row[2], type=row[3], length=row[4] ) for row in rows ] diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index f5356e7f80..22025eca56 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -179,46 +179,44 @@ class PushRulesWorkerStore( @cached(max_entries=5000) async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules: - rows = await self.db_pool.simple_select_list( - table="push_rules", - keyvalues={"user_name": user_id}, - retcols=( - "user_name", - "rule_id", - "priority_class", - "priority", - "conditions", - "actions", + rows = cast( + List[Tuple[str, int, int, str, str]], + await self.db_pool.simple_select_list( + table="push_rules", + keyvalues={"user_name": user_id}, + retcols=( + "rule_id", + "priority_class", + "priority", + "conditions", + "actions", + ), + desc="get_push_rules_for_user", ), - desc="get_push_rules_for_user", ) - rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) + # Sort by highest priority_class, then highest priority. 
+ rows.sort(key=lambda row: (-int(row[1]), -int(row[2]))) enabled_map = await self.get_push_rules_enabled_for_user(user_id) return _load_rules( - [ - ( - row["rule_id"], - row["priority_class"], - row["conditions"], - row["actions"], - ) - for row in rows - ], + [(row[0], row[1], row[3], row[4]) for row in rows], enabled_map, self.hs.config.experimental, ) async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]: - results = await self.db_pool.simple_select_list( - table="push_rules_enable", - keyvalues={"user_name": user_id}, - retcols=("rule_id", "enabled"), - desc="get_push_rules_enabled_for_user", + results = cast( + List[Tuple[str, Optional[Union[int, bool]]]], + await self.db_pool.simple_select_list( + table="push_rules_enable", + keyvalues={"user_name": user_id}, + retcols=("rule_id", "enabled"), + desc="get_push_rules_enabled_for_user", + ), ) - return {r["rule_id"]: bool(r["enabled"]) for r in results} + return {r[0]: bool(r[1]) for r in results} async def have_push_rules_changed_for_user( self, user_id: str, last_id: int diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index c7eb7fc478..a6a1671bd6 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -371,18 +371,20 @@ class PusherWorkerStore(SQLBaseStore): async def get_throttle_params_by_room( self, pusher_id: int ) -> Dict[str, ThrottleParams]: - res = await self.db_pool.simple_select_list( - "pusher_throttle", - {"pusher": pusher_id}, - ["room_id", "last_sent_ts", "throttle_ms"], - desc="get_throttle_params_by_room", + res = cast( + List[Tuple[str, Optional[int], Optional[int]]], + await self.db_pool.simple_select_list( + "pusher_throttle", + {"pusher": pusher_id}, + ["room_id", "last_sent_ts", "throttle_ms"], + desc="get_throttle_params_by_room", + ), ) params_by_room = {} - for row in res: - params_by_room[row["room_id"]] = ThrottleParams( - row["last_sent_ts"], - 
row["throttle_ms"], + for room_id, last_sent_ts, throttle_ms in res: + params_by_room[room_id] = ThrottleParams( + last_sent_ts or 0, throttle_ms or 0 ) return params_by_room diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 9e8643ae4d..b0ef7be155 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -855,13 +855,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): Returns: Tuples of (auth_provider, external_id) """ - res = await self.db_pool.simple_select_list( - table="user_external_ids", - keyvalues={"user_id": mxid}, - retcols=("auth_provider", "external_id"), - desc="get_external_ids_by_user", + return cast( + List[Tuple[str, str]], + await self.db_pool.simple_select_list( + table="user_external_ids", + keyvalues={"user_id": mxid}, + retcols=("auth_provider", "external_id"), + desc="get_external_ids_by_user", + ), ) - return [(r["auth_provider"], r["external_id"]) for r in res] async def count_all_users(self) -> int: """Counts all users registered on the homeserver.""" @@ -997,13 +999,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): ) async def user_get_threepids(self, user_id: str) -> List[ThreepidResult]: - results = await self.db_pool.simple_select_list( - "user_threepids", - keyvalues={"user_id": user_id}, - retcols=["medium", "address", "validated_at", "added_at"], - desc="user_get_threepids", + results = cast( + List[Tuple[str, str, int, int]], + await self.db_pool.simple_select_list( + "user_threepids", + keyvalues={"user_id": user_id}, + retcols=["medium", "address", "validated_at", "added_at"], + desc="user_get_threepids", + ), ) - return [ThreepidResult(**r) for r in results] + return [ + ThreepidResult( + medium=r[0], + address=r[1], + validated_at=r[2], + added_at=r[3], + ) + for r in results + ] async def user_delete_threepid( self, user_id: str, medium: str, address: str @@ -1042,7 +1055,7 
@@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="add_user_bound_threepid", ) - async def user_get_bound_threepids(self, user_id: str) -> List[Dict[str, Any]]: + async def user_get_bound_threepids(self, user_id: str) -> List[Tuple[str, str]]: """Get the threepids that a user has bound to an identity server through the homeserver The homeserver remembers where binds to an identity server occurred. Using this method can retrieve those threepids. @@ -1051,15 +1064,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): user_id: The ID of the user to retrieve threepids for Returns: - List of dictionaries containing the following keys: - medium (str): The medium of the threepid (e.g "email") - address (str): The address of the threepid (e.g "bob@example.com") - """ - return await self.db_pool.simple_select_list( - table="user_threepid_id_server", - keyvalues={"user_id": user_id}, - retcols=["medium", "address"], - desc="user_get_bound_threepids", + List of tuples of two strings: + medium: The medium of the threepid (e.g "email") + address: The address of the threepid (e.g "bob@example.com") + """ + return cast( + List[Tuple[str, str]], + await self.db_pool.simple_select_list( + table="user_threepid_id_server", + keyvalues={"user_id": user_id}, + retcols=["medium", "address"], + desc="user_get_bound_threepids", + ), ) async def remove_user_bound_threepid( diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index ce7bfd5146..419b2c7a22 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -384,14 +384,17 @@ class RelationsWorkerStore(SQLBaseStore): def get_all_relation_ids_for_event_txn( txn: LoggingTransaction, ) -> List[str]: - rows = self.db_pool.simple_select_list_txn( - txn=txn, - table="event_relations", - keyvalues={"relates_to_id": event_id}, - retcols=["event_id"], + rows = cast( + List[Tuple[str]], + 
self.db_pool.simple_select_list_txn( + txn=txn, + table="event_relations", + keyvalues={"relates_to_id": event_id}, + retcols=["event_id"], + ), ) - return [row["event_id"] for row in rows] + return [row[0] for row in rows] return await self.db_pool.runInteraction( desc="get_all_relation_ids_for_event", diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 9d24d2c347..3e8fcf1975 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1232,28 +1232,30 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): """ room_servers: Dict[str, PartialStateResyncInfo] = {} - rows = await self.db_pool.simple_select_list( - table="partial_state_rooms", - keyvalues={}, - retcols=("room_id", "joined_via"), - desc="get_server_which_served_partial_join", + rows = cast( + List[Tuple[str, str]], + await self.db_pool.simple_select_list( + table="partial_state_rooms", + keyvalues={}, + retcols=("room_id", "joined_via"), + desc="get_server_which_served_partial_join", + ), ) - for row in rows: - room_id = row["room_id"] - joined_via = row["joined_via"] + for room_id, joined_via in rows: room_servers[room_id] = PartialStateResyncInfo(joined_via=joined_via) - rows = await self.db_pool.simple_select_list( - "partial_state_rooms_servers", - keyvalues=None, - retcols=("room_id", "server_name"), - desc="get_partial_state_rooms", + rows = cast( + List[Tuple[str, str]], + await self.db_pool.simple_select_list( + "partial_state_rooms_servers", + keyvalues=None, + retcols=("room_id", "server_name"), + desc="get_partial_state_rooms", + ), ) - for row in rows: - room_id = row["room_id"] - server_name = row["server_name"] + for room_id, server_name in rows: entry = room_servers.get(room_id) if entry is None: # There is a foreign key constraint which enforces that every room_id in diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 3a87eba430..a1627dffb7 
100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -1070,13 +1070,16 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): for fully-joined rooms. """ - rows = await self.db_pool.simple_select_list( - "current_state_events", - keyvalues={"room_id": room_id}, - retcols=("event_id", "membership"), - desc="has_completed_background_updates", + rows = cast( + List[Tuple[str, Optional[str]]], + await self.db_pool.simple_select_list( + "current_state_events", + keyvalues={"room_id": room_id}, + retcols=("event_id", "membership"), + desc="has_completed_background_updates", + ), ) - return {row["event_id"]: row["membership"] for row in rows} + return dict(rows) # TODO This returns a mutable object, which is generally confusing when using a cache. @cached(max_entries=10000) # type: ignore[synapse-@cached-mutable] diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index 61403a98cf..7deda7790e 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -45,14 +45,17 @@ class TagsWorkerStore(AccountDataWorkerStore): tag content. """ - rows = await self.db_pool.simple_select_list( - "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"] + rows = cast( + List[Tuple[str, str, str]], + await self.db_pool.simple_select_list( + "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"] + ), ) tags_by_room: Dict[str, Dict[str, JsonDict]] = {} - for row in rows: - room_tags = tags_by_room.setdefault(row["room_id"], {}) - room_tags[row["tag"]] = db_to_json(row["content"]) + for room_id, tag, content in rows: + room_tags = tags_by_room.setdefault(room_id, {}) + room_tags[tag] = db_to_json(content) return tags_by_room async def get_all_updated_tags( @@ -161,13 +164,16 @@ class TagsWorkerStore(AccountDataWorkerStore): Returns: A mapping of tags to tag content. 
""" - rows = await self.db_pool.simple_select_list( - table="room_tags", - keyvalues={"user_id": user_id, "room_id": room_id}, - retcols=("tag", "content"), - desc="get_tags_for_room", + rows = cast( + List[Tuple[str, str]], + await self.db_pool.simple_select_list( + table="room_tags", + keyvalues={"user_id": user_id, "room_id": room_id}, + retcols=("tag", "content"), + desc="get_tags_for_room", + ), ) - return {row["tag"]: db_to_json(row["content"]) for row in rows} + return {tag: db_to_json(content) for tag, content in rows} async def add_tag_to_room( self, user_id: str, room_id: str, tag: str, content: JsonDict diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 919c66f553..8ab7c42c4a 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -169,13 +169,17 @@ class UIAuthWorkerStore(SQLBaseStore): that auth-type. """ results = {} - for row in await self.db_pool.simple_select_list( - table="ui_auth_sessions_credentials", - keyvalues={"session_id": session_id}, - retcols=("stage_type", "result"), - desc="get_completed_ui_auth_stages", - ): - results[row["stage_type"]] = db_to_json(row["result"]) + rows = cast( + List[Tuple[str, str]], + await self.db_pool.simple_select_list( + table="ui_auth_sessions_credentials", + keyvalues={"session_id": session_id}, + retcols=("stage_type", "result"), + desc="get_completed_ui_auth_stages", + ), + ) + for stage_type, result in rows: + results[stage_type] = db_to_json(result) return results @@ -295,13 +299,15 @@ class UIAuthWorkerStore(SQLBaseStore): Returns: List of user_agent/ip pairs """ - rows = await self.db_pool.simple_select_list( - table="ui_auth_sessions_ips", - keyvalues={"session_id": session_id}, - retcols=("user_agent", "ip"), - desc="get_user_agents_ips_to_ui_auth_session", + return cast( + List[Tuple[str, str]], + await self.db_pool.simple_select_list( + table="ui_auth_sessions_ips", + keyvalues={"session_id": 
session_id}, + retcols=("user_agent", "ip"), + desc="get_user_agents_ips_to_ui_auth_session", + ), ) - return [(row["user_agent"], row["ip"]) for row in rows] async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None: """ diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 09d2a8c5b3..182e429174 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -154,16 +154,22 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): if not prev_group: return _GetStateGroupDelta(None, None) - delta_ids = self.db_pool.simple_select_list_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": state_group}, - retcols=("type", "state_key", "event_id"), + delta_ids = cast( + List[Tuple[str, str, str]], + self.db_pool.simple_select_list_txn( + txn, + table="state_groups_state", + keyvalues={"state_group": state_group}, + retcols=("type", "state_key", "event_id"), + ), ) return _GetStateGroupDelta( prev_group, - {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, + { + (event_type, state_key): event_id + for event_type, state_key, event_id in delta_ids + }, ) return await self.db_pool.runInteraction( diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index d11ded6c5b..76c56d5434 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple, cast from twisted.test.proto_helpers import MemoryReactor @@ -68,10 +68,14 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) - async def get_all_room_state(self) -> List[Dict[str, Any]]: - return await self.store.db_pool.simple_select_list( - "room_stats_state", None, retcols=("name", "topic", "canonical_alias") + async def get_all_room_state(self) -> List[Optional[str]]: + rows = cast( + List[Tuple[Optional[str]]], + await self.store.db_pool.simple_select_list( + "room_stats_state", None, retcols=("topic",) + ), ) + return [r[0] for r in rows] def _get_current_stats( self, stats_type: str, stat_id: str @@ -130,7 +134,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): r = self.get_success(self.get_all_room_state()) self.assertEqual(len(r), 1) - self.assertEqual(r[0]["topic"], "foo") + self.assertEqual(r[0], "foo") def test_create_user(self) -> None: """ diff --git a/tests/storage/databases/main/test_receipts.py b/tests/storage/databases/main/test_receipts.py index 71db47405e..98b01086bc 100644 --- a/tests/storage/databases/main/test_receipts.py +++ b/tests/storage/databases/main/test_receipts.py @@ -117,7 +117,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase): if expected_row is not None: columns += expected_row.keys() - rows = self.get_success( + row_tuples = self.get_success( self.store.db_pool.simple_select_list( table=table, keyvalues={ @@ -134,22 +134,22 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase): if expected_row is not None: self.assertEqual( - len(rows), + len(row_tuples), 1, f"Background update did not leave behind latest receipt in {table}", ) self.assertEqual( - rows[0], - { - "room_id": room_id, - "receipt_type": receipt_type, - "user_id": user_id, - **expected_row, - }, + row_tuples[0], + ( + room_id, + receipt_type, + user_id, + *expected_row.values(), + ), ) else: self.assertEqual( - len(rows), + 
len(row_tuples), 0, f"Background update did not remove all duplicate receipts from {table}", ) diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 8bbf936ae9..8cbc974ac4 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -14,7 +14,7 @@ # limitations under the License. import secrets -from typing import Generator, Tuple +from typing import Generator, List, Tuple, cast from twisted.test.proto_helpers import MemoryReactor @@ -47,15 +47,15 @@ class UpdateUpsertManyTests(unittest.HomeserverTestCase): ) def _dump_table_to_tuple(self) -> Generator[Tuple[int, str, str], None, None]: - res = self.get_success( - self.storage.db_pool.simple_select_list( - self.table_name, None, ["id, username, value"] - ) + yield from cast( + List[Tuple[int, str, str]], + self.get_success( + self.storage.db_pool.simple_select_list( + self.table_name, None, ["id, username, value"] + ) + ), ) - for i in res: - yield (i["id"], i["username"], i["value"]) - def test_upsert_many(self) -> None: """ Upsert_many will perform the upsert operation across a batch of data. diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index abf7d0564d..3f5bfa09d4 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import List, Tuple, cast from unittest.mock import AsyncMock, Mock import yaml @@ -526,15 +527,18 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase): self.wait_for_background_updates() # Check the correct values are in the new table. 
- rows = self.get_success( - self.store.db_pool.simple_select_list( - table="test_constraint", - keyvalues={}, - retcols=("a", "b"), - ) + rows = cast( + List[Tuple[int, int]], + self.get_success( + self.store.db_pool.simple_select_list( + table="test_constraint", + keyvalues={}, + retcols=("a", "b"), + ) + ), ) - self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}]) + self.assertCountEqual(rows, [(1, 1), (3, 3)]) # And check that invalid rows get correctly rejected. self.get_failure( @@ -640,14 +644,17 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase): self.wait_for_background_updates() # Check the correct values are in the new table. - rows = self.get_success( - self.store.db_pool.simple_select_list( - table="test_constraint", - keyvalues={}, - retcols=("a", "b"), - ) + rows = cast( + List[Tuple[int, int]], + self.get_success( + self.store.db_pool.simple_select_list( + table="test_constraint", + keyvalues={}, + retcols=("a", "b"), + ) + ), ) - self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}]) + self.assertCountEqual(rows, [(1, 1), (3, 3)]) # And check that invalid rows get correctly rejected. 
self.get_failure( diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 256d28e4c9..e4a52c301e 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -146,7 +146,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]: self.mock_txn.rowcount = 3 - self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)])) + self.mock_txn.fetchall.return_value = [(1,), (2,), (3,)] self.mock_txn.description = (("colA", None, None, None, None, None, None),) ret = yield defer.ensureDeferred( @@ -155,7 +155,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) ) - self.assertEqual([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret) + self.assertEqual([(1,), (2,), (3,)], ret) self.mock_txn.execute.assert_called_with( "SELECT colA FROM tablename WHERE keycol = ?", ["A set"] ) diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 0c054a598f..8e4393d843 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict +from typing import Any, Dict, List, Optional, Tuple, cast from unittest.mock import AsyncMock from parameterized import parameterized @@ -97,26 +97,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.reactor.advance(200) self.pump(0) - result = self.get_success( - self.store.db_pool.simple_select_list( - table="user_ips", - keyvalues={"user_id": user_id}, - retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], - desc="get_user_ip_and_agents", - ) + result = cast( + List[Tuple[str, str, str, Optional[str], int]], + self.get_success( + self.store.db_pool.simple_select_list( + table="user_ips", + keyvalues={"user_id": user_id}, + retcols=[ + "access_token", + "ip", + "user_agent", + "device_id", + "last_seen", + ], + desc="get_user_ip_and_agents", + ) + ), ) self.assertEqual( - result, - [ - { - "access_token": "access_token", - "ip": "ip", - "user_agent": "user_agent", - "device_id": None, - "last_seen": 12345678000, - } - ], + result, [("access_token", "ip", "user_agent", None, 12345678000)] ) # Add another & trigger the storage loop @@ -128,26 +128,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.reactor.advance(10) self.pump(0) - result = self.get_success( - self.store.db_pool.simple_select_list( - table="user_ips", - keyvalues={"user_id": user_id}, - retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], - desc="get_user_ip_and_agents", - ) + result = cast( + List[Tuple[str, str, str, Optional[str], int]], + self.get_success( + self.store.db_pool.simple_select_list( + table="user_ips", + keyvalues={"user_id": user_id}, + retcols=[ + "access_token", + "ip", + "user_agent", + "device_id", + "last_seen", + ], + desc="get_user_ip_and_agents", + ) + ), ) # Only one result, has been upserted. 
self.assertEqual( - result, - [ - { - "access_token": "access_token", - "ip": "ip", - "user_agent": "user_agent", - "device_id": None, - "last_seen": 12345878000, - } - ], + result, [("access_token", "ip", "user_agent", None, 12345878000)] ) @parameterized.expand([(False,), (True,)]) @@ -177,25 +177,23 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.reactor.advance(10) else: # Check that the new IP and user agent has not been stored yet - db_result = self.get_success( - self.store.db_pool.simple_select_list( - table="devices", - keyvalues={}, - retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), + db_result = cast( + List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]], + self.get_success( + self.store.db_pool.simple_select_list( + table="devices", + keyvalues={}, + retcols=( + "user_id", + "ip", + "user_agent", + "device_id", + "last_seen", + ), + ), ), ) - self.assertEqual( - db_result, - [ - { - "user_id": user_id, - "device_id": device_id, - "ip": None, - "user_agent": None, - "last_seen": None, - }, - ], - ) + self.assertEqual(db_result, [(user_id, None, None, device_id, None)]) result = self.get_success( self.store.get_last_client_ip_by_device(user_id, device_id) @@ -261,30 +259,21 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) # Check that the new IP and user agent has not been stored yet - db_result = self.get_success( - self.store.db_pool.simple_select_list( - table="devices", - keyvalues={}, - retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), + db_result = cast( + List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]], + self.get_success( + self.store.db_pool.simple_select_list( + table="devices", + keyvalues={}, + retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), + ), ), ) self.assertCountEqual( db_result, [ - { - "user_id": user_id, - "device_id": device_id_1, - "ip": "ip_1", - "user_agent": "user_agent_1", - "last_seen": 12345678000, - }, - { - 
"user_id": user_id, - "device_id": device_id_2, - "ip": "ip_2", - "user_agent": "user_agent_2", - "last_seen": 12345678000, - }, + (user_id, "ip_1", "user_agent_1", device_id_1, 12345678000), + (user_id, "ip_2", "user_agent_2", device_id_2, 12345678000), ], ) @@ -385,28 +374,21 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) # Check that the new IP and user agent has not been stored yet - db_result = self.get_success( - self.store.db_pool.simple_select_list( - table="user_ips", - keyvalues={}, - retcols=("access_token", "ip", "user_agent", "last_seen"), + db_result = cast( + List[Tuple[str, str, str, int]], + self.get_success( + self.store.db_pool.simple_select_list( + table="user_ips", + keyvalues={}, + retcols=("access_token", "ip", "user_agent", "last_seen"), + ), ), ) self.assertEqual( db_result, [ - { - "access_token": "access_token", - "ip": "ip_1", - "user_agent": "user_agent_1", - "last_seen": 12345678000, - }, - { - "access_token": "access_token", - "ip": "ip_2", - "user_agent": "user_agent_2", - "last_seen": 12345678000, - }, + ("access_token", "ip_1", "user_agent_1", 12345678000), + ("access_token", "ip_2", "user_agent_2", 12345678000), ], ) @@ -600,39 +582,49 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.reactor.advance(200) # We should see that in the DB - result = self.get_success( - self.store.db_pool.simple_select_list( - table="user_ips", - keyvalues={"user_id": user_id}, - retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], - desc="get_user_ip_and_agents", - ) + result = cast( + List[Tuple[str, str, str, Optional[str], int]], + self.get_success( + self.store.db_pool.simple_select_list( + table="user_ips", + keyvalues={"user_id": user_id}, + retcols=[ + "access_token", + "ip", + "user_agent", + "device_id", + "last_seen", + ], + desc="get_user_ip_and_agents", + ) + ), ) self.assertEqual( result, - [ - { - "access_token": "access_token", - "ip": "ip", - "user_agent": "user_agent", - "device_id": 
device_id, - "last_seen": 0, - } - ], + [("access_token", "ip", "user_agent", device_id, 0)], ) # Now advance by a couple of months self.reactor.advance(60 * 24 * 60 * 60) # We should get no results. - result = self.get_success( - self.store.db_pool.simple_select_list( - table="user_ips", - keyvalues={"user_id": user_id}, - retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], - desc="get_user_ip_and_agents", - ) + result = cast( + List[Tuple[str, str, str, Optional[str], int]], + self.get_success( + self.store.db_pool.simple_select_list( + table="user_ips", + keyvalues={"user_id": user_id}, + retcols=[ + "access_token", + "ip", + "user_agent", + "device_id", + "last_seen", + ], + desc="get_user_ip_and_agents", + ) + ), ) self.assertEqual(result, []) @@ -696,28 +688,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.reactor.advance(200) # We should see that in the DB - result = self.get_success( - self.store.db_pool.simple_select_list( - table="user_ips", - keyvalues={}, - retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], - desc="get_user_ip_and_agents", - ) + result = cast( + List[Tuple[str, str, str, Optional[str], int]], + self.get_success( + self.store.db_pool.simple_select_list( + table="user_ips", + keyvalues={}, + retcols=[ + "access_token", + "ip", + "user_agent", + "device_id", + "last_seen", + ], + desc="get_user_ip_and_agents", + ) + ), ) # ensure user1 is filtered out - self.assertEqual( - result, - [ - { - "access_token": access_token2, - "ip": "ip", - "user_agent": "user_agent", - "device_id": device_id2, - "last_seen": 0, - } - ], - ) + self.assertEqual(result, [(access_token2, "ip", "user_agent", device_id2, 0)]) class ClientIpAuthTestCase(unittest.HomeserverTestCase): diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index f4c4661aaf..36fcab06b5 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -12,6 +12,8 @@ # WITHOUT 
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Optional, Tuple, cast + from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import Membership @@ -110,21 +112,24 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): def test__null_byte_in_display_name_properly_handled(self) -> None: room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) - res = self.get_success( - self.store.db_pool.simple_select_list( - "room_memberships", - {"user_id": "@alice:test"}, - ["display_name", "event_id"], - ) + res = cast( + List[Tuple[Optional[str], str]], + self.get_success( + self.store.db_pool.simple_select_list( + "room_memberships", + {"user_id": "@alice:test"}, + ["display_name", "event_id"], + ) + ), ) # Check that we only got one result back self.assertEqual(len(res), 1) # Check that alice's display name is "alice" - self.assertEqual(res[0]["display_name"], "alice") + self.assertEqual(res[0][0], "alice") # Grab the event_id to use later - event_id = res[0]["event_id"] + event_id = res[0][1] # Create a profile with the offending null byte in the display name new_profile = {"displayname": "ali\u0000ce"} @@ -139,21 +144,24 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): tok=self.t_alice, ) - res2 = self.get_success( - self.store.db_pool.simple_select_list( - "room_memberships", - {"user_id": "@alice:test"}, - ["display_name", "event_id"], - ) + res2 = cast( + List[Tuple[Optional[str], str]], + self.get_success( + self.store.db_pool.simple_select_list( + "room_memberships", + {"user_id": "@alice:test"}, + ["display_name", "event_id"], + ) + ), ) # Check that we only have two results self.assertEqual(len(res2), 2) # Filter out the previous event using the event_id we grabbed above - row = [row for row in res2 if row["event_id"] != event_id] + row = [row for row in res2 if 
row[1] != event_id] # Check that alice's display name is now None - self.assertEqual(row[0]["display_name"], None) + self.assertIsNone(row[0][0]) def test_room_is_locally_forgotten(self) -> None: """Test that when the last local user has forgotten a room it is known as forgotten.""" diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 0b9446c36c..2715c73f16 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -13,6 +13,7 @@ # limitations under the License. import logging +from typing import List, Tuple, cast from immutabledict import immutabledict @@ -584,18 +585,21 @@ class StateStoreTestCase(HomeserverTestCase): ) # check that only state events are in state_groups, and all state events are in state_groups - res = self.get_success( - self.store.db_pool.simple_select_list( - table="state_groups", - keyvalues=None, - retcols=("event_id",), - ) + res = cast( + List[Tuple[str]], + self.get_success( + self.store.db_pool.simple_select_list( + table="state_groups", + keyvalues=None, + retcols=("event_id",), + ) + ), ) events = [] for result in res: - self.assertNotIn(event3.event_id, result) - events.append(result.get("event_id")) + self.assertNotIn(event3.event_id, result) # XXX + events.append(result[0]) for event, _ in processed_events_and_context: if event.is_state(): @@ -606,23 +610,29 @@ class StateStoreTestCase(HomeserverTestCase): # has an entry and prev event in state_group_edges for event, context in processed_events_and_context: if event.is_state(): - state = self.get_success( - self.store.db_pool.simple_select_list( - table="state_groups_state", - keyvalues={"state_group": context.state_group_after_event}, - retcols=("type", "state_key"), - ) - ) - self.assertEqual(event.type, state[0].get("type")) - self.assertEqual(event.state_key, state[0].get("state_key")) - - groups = self.get_success( - self.store.db_pool.simple_select_list( - table="state_group_edges", - keyvalues={"state_group": 
str(context.state_group_after_event)}, - retcols=("*",), - ) + state = cast( + List[Tuple[str, str]], + self.get_success( + self.store.db_pool.simple_select_list( + table="state_groups_state", + keyvalues={"state_group": context.state_group_after_event}, + retcols=("type", "state_key"), + ) + ), ) - self.assertEqual( - context.state_group_before_event, groups[0].get("prev_state_group") + self.assertEqual(event.type, state[0][0]) + self.assertEqual(event.state_key, state[0][1]) + + groups = cast( + List[Tuple[str]], + self.get_success( + self.store.db_pool.simple_select_list( + table="state_group_edges", + keyvalues={ + "state_group": str(context.state_group_after_event) + }, + retcols=("prev_state_group",), + ) + ), ) + self.assertEqual(context.state_group_before_event, groups[0][0]) diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 8c72aa1722..822c41dd9f 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import re -from typing import Any, Dict, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple, cast from unittest import mock from unittest.mock import Mock, patch @@ -62,14 +62,13 @@ class GetUserDirectoryTables: Returns a list of tuples (user_id, room_id) where room_id is public and contains the user with the given id. """ - r = await self.store.db_pool.simple_select_list( - "users_in_public_rooms", None, ("user_id", "room_id") + r = cast( + List[Tuple[str, str]], + await self.store.db_pool.simple_select_list( + "users_in_public_rooms", None, ("user_id", "room_id") + ), ) - - retval = set() - for i in r: - retval.add((i["user_id"], i["room_id"])) - return retval + return set(r) async def get_users_who_share_private_rooms(self) -> Set[Tuple[str, str, str]]: """Fetch the entire `users_who_share_private_rooms` table. 
@@ -78,27 +77,30 @@ class GetUserDirectoryTables: to the rows of `users_who_share_private_rooms`. """ - rows = await self.store.db_pool.simple_select_list( - "users_who_share_private_rooms", - None, - ["user_id", "other_user_id", "room_id"], + rows = cast( + List[Tuple[str, str, str]], + await self.store.db_pool.simple_select_list( + "users_who_share_private_rooms", + None, + ["user_id", "other_user_id", "room_id"], + ), ) - rv = set() - for row in rows: - rv.add((row["user_id"], row["other_user_id"], row["room_id"])) - return rv + return set(rows) async def get_users_in_user_directory(self) -> Set[str]: """Fetch the set of users in the `user_directory` table. This is useful when checking we've correctly excluded users from the directory. """ - result = await self.store.db_pool.simple_select_list( - "user_directory", - None, - ["user_id"], + result = cast( + List[Tuple[str]], + await self.store.db_pool.simple_select_list( + "user_directory", + None, + ["user_id"], + ), ) - return {row["user_id"] for row in result} + return {row[0] for row in result} async def get_profiles_in_user_directory(self) -> Dict[str, ProfileInfo]: """Fetch users and their profiles from the `user_directory` table. @@ -107,16 +109,17 @@ class GetUserDirectoryTables: It's almost the entire contents of the `user_directory` table: the only thing missing is an unused room_id column. 
""" - rows = await self.store.db_pool.simple_select_list( - "user_directory", - None, - ("user_id", "display_name", "avatar_url"), + rows = cast( + List[Tuple[str, Optional[str], Optional[str]]], + await self.store.db_pool.simple_select_list( + "user_directory", + None, + ("user_id", "display_name", "avatar_url"), + ), ) return { - row["user_id"]: ProfileInfo( - display_name=row["display_name"], avatar_url=row["avatar_url"] - ) - for row in rows + user_id: ProfileInfo(display_name=display_name, avatar_url=avatar_url) + for user_id, display_name, avatar_url in rows } async def get_tables( -- cgit 1.5.1 From 679c691f6f7c4f7901e6d075a645a8ade20f44d5 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 26 Oct 2023 15:12:28 -0400 Subject: Remove more usages of cursor_to_dict. (#16551) Mostly to improve type safety. --- changelog.d/16551.misc | 1 + synapse/handlers/identity.py | 18 ++++---- synapse/handlers/ui_auth/checkers.py | 6 +-- synapse/media/media_repository.py | 5 +-- synapse/rest/admin/federation.py | 14 +++++- synapse/rest/admin/rooms.py | 12 ++++- synapse/rest/admin/statistics.py | 13 +++++- synapse/storage/database.py | 30 ++----------- synapse/storage/databases/main/censor_events.py | 2 +- synapse/storage/databases/main/devices.py | 3 +- synapse/storage/databases/main/end_to_end_keys.py | 1 - .../storage/databases/main/events_bg_updates.py | 7 +-- .../databases/main/events_forward_extremities.py | 15 ++++--- synapse/storage/databases/main/media_repository.py | 19 ++++---- synapse/storage/databases/main/registration.py | 43 ++++++++++++------ synapse/storage/databases/main/roommember.py | 4 +- synapse/storage/databases/main/search.py | 52 +++++++++++++--------- synapse/storage/databases/main/stats.py | 15 ++++--- synapse/storage/databases/main/stream.py | 3 +- synapse/storage/databases/main/transactions.py | 28 ++++++++++-- synapse/storage/databases/main/user_directory.py | 14 +++--- synapse/storage/databases/state/bg_updates.py | 1 - 
tests/federation/test_federation_catch_up.py | 1 - tests/storage/test_background_update.py | 16 +++---- tests/storage/test_profile.py | 2 +- tests/storage/test_user_filters.py | 2 +- 26 files changed, 193 insertions(+), 134 deletions(-) create mode 100644 changelog.d/16551.misc (limited to 'synapse/storage/databases/main/devices.py') diff --git a/changelog.d/16551.misc b/changelog.d/16551.misc new file mode 100644 index 0000000000..93ceaeafc9 --- /dev/null +++ b/changelog.d/16551.misc @@ -0,0 +1 @@ +Improve type hints. diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 472879c964..c041b67993 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -19,6 +19,8 @@ import logging import urllib.parse from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Tuple +import attr + from synapse.api.errors import ( CodeMessageException, Codes, @@ -357,9 +359,9 @@ class IdentityHandler: # Check to see if a session already exists and that it is not yet # marked as validated - if session and session.get("validated_at") is None: - session_id = session["session_id"] - last_send_attempt = session["last_send_attempt"] + if session and session.validated_at is None: + session_id = session.session_id + last_send_attempt = session.last_send_attempt # Check that the send_attempt is higher than previous attempts if send_attempt <= last_send_attempt: @@ -480,7 +482,6 @@ class IdentityHandler: # We don't actually know which medium this 3PID is. 
Thus we first assume it's email, # and if validation fails we try msisdn - validation_session = None # Try to validate as email if self.hs.config.email.can_verify_email: @@ -488,19 +489,18 @@ class IdentityHandler: validation_session = await self.store.get_threepid_validation_session( "email", client_secret, sid=sid, validated=True ) - - if validation_session: - return validation_session + if validation_session: + return attr.asdict(validation_session) # Try to validate as msisdn if self.hs.config.registration.account_threepid_delegate_msisdn: # Ask our delegated msisdn identity server - validation_session = await self.threepid_from_creds( + return await self.threepid_from_creds( self.hs.config.registration.account_threepid_delegate_msisdn, threepid_creds, ) - return validation_session + return None async def proxy_msisdn_submit_token( self, id_server: str, client_secret: str, sid: str, token: str diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 78a75bfed6..ab8f7610e9 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -187,9 +187,9 @@ class _BaseThreepidAuthChecker: if row: threepid = { - "medium": row["medium"], - "address": row["address"], - "validated_at": row["validated_at"], + "medium": row.medium, + "address": row.address, + "validated_at": row.validated_at, } # Valid threepid returned, delete from the db diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index 7fd46901f7..72b0f1c5de 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -949,10 +949,7 @@ class MediaRepository: deleted = 0 - for media in old_media: - origin = media["media_origin"] - media_id = media["media_id"] - file_id = media["filesystem_id"] + for origin, media_id, file_id in old_media: key = (origin, media_id) logger.info("Deleting: %r", key) diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py index 
8a617af599..a6ce787da1 100644 --- a/synapse/rest/admin/federation.py +++ b/synapse/rest/admin/federation.py @@ -85,7 +85,19 @@ class ListDestinationsRestServlet(RestServlet): destinations, total = await self._store.get_destinations_paginate( start, limit, destination, order_by, direction ) - response = {"destinations": destinations, "total": total} + response = { + "destinations": [ + { + "destination": r[0], + "retry_last_ts": r[1], + "retry_interval": r[2], + "failure_ts": r[3], + "last_successful_stream_ordering": r[4], + } + for r in destinations + ], + "total": total, + } if (start + limit) < total: response["next_token"] = str(start + len(destinations)) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 436718c8b2..2d4da38db9 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -724,7 +724,17 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet): room_id, _ = await self.resolve_room_id(room_identifier) extremities = await self.store.get_forward_extremities_for_room(room_id) - return HTTPStatus.OK, {"count": len(extremities), "results": extremities} + result = [ + { + "event_id": ex[0], + "state_group": ex[1], + "depth": ex[2], + "received_ts": ex[3], + } + for ex in extremities + ] + + return HTTPStatus.OK, {"count": len(extremities), "results": result} class RoomEventContextServlet(RestServlet): diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py index 19780e4b4c..75d8a37ccf 100644 --- a/synapse/rest/admin/statistics.py +++ b/synapse/rest/admin/statistics.py @@ -108,7 +108,18 @@ class UserMediaStatisticsRestServlet(RestServlet): users_media, total = await self.store.get_users_media_usage_paginate( start, limit, from_ts, until_ts, order_by, direction, search_term ) - ret = {"users": users_media, "total": total} + ret = { + "users": [ + { + "user_id": r[0], + "displayname": r[1], + "media_count": r[2], + "media_length": r[3], + } + for r in users_media + ], + 
"total": total, + } if (start + limit) < total: ret["next_token"] = start + len(users_media) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 774d5c12f0..b1ece63845 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -35,7 +35,6 @@ from typing import ( Tuple, Type, TypeVar, - Union, cast, overload, ) @@ -1047,43 +1046,20 @@ class DatabasePool: results = [dict(zip(col_headers, row)) for row in cursor] return results - @overload - async def execute( - self, desc: str, decoder: Literal[None], query: str, *args: Any - ) -> List[Tuple[Any, ...]]: - ... - - @overload - async def execute( - self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any - ) -> R: - ... - - async def execute( - self, - desc: str, - decoder: Optional[Callable[[Cursor], R]], - query: str, - *args: Any, - ) -> Union[List[Tuple[Any, ...]], R]: + async def execute(self, desc: str, query: str, *args: Any) -> List[Tuple[Any, ...]]: """Runs a single query for a result set. Args: desc: description of the transaction, for logging and metrics - decoder - The function which can resolve the cursor results to - something meaningful. query - The query string to execute *args - Query args. 
Returns: The result of decoder(results) """ - def interaction(txn: LoggingTransaction) -> Union[List[Tuple[Any, ...]], R]: + def interaction(txn: LoggingTransaction) -> List[Tuple[Any, ...]]: txn.execute(query, args) - if decoder: - return decoder(txn) - else: - return txn.fetchall() + return txn.fetchall() return await self.runInteraction(desc, interaction) diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py index 58177ecec1..711fdddd4e 100644 --- a/synapse/storage/databases/main/censor_events.py +++ b/synapse/storage/databases/main/censor_events.py @@ -93,7 +93,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase """ rows = await self.db_pool.execute( - "_censor_redactions_fetch", None, sql, before_ts, 100 + "_censor_redactions_fetch", sql, before_ts, 100 ) updates = [] diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 0b75f6763a..49edbb9e06 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -894,7 +894,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): rows = await self.db_pool.execute( "get_all_devices_changed", - None, sql, from_key, to_key, @@ -978,7 +977,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): WHERE from_user_id = ? AND stream_id > ? 
""" rows = await self.db_pool.execute( - "get_users_whose_signatures_changed", None, sql, user_id, from_key + "get_users_whose_signatures_changed", sql, user_id, from_key ) return {user for row in rows for user in db_to_json(row[0])} else: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index f13d776b0d..f70f95eeba 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -155,7 +155,6 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker """ rows = await self.db_pool.execute( "get_e2e_device_keys_for_federation_query_check", - None, sql, now_stream_id, user_id, diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index c5fce1c82b..0061805150 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -1310,12 +1310,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): # ANALYZE the new column to build stats on it, to encourage PostgreSQL to use the # indexes on it. - # We need to pass execute a dummy function to handle the txn's result otherwise - # it tries to call fetchall() on it and fails because there's no result to fetch. - await self.db_pool.execute( + await self.db_pool.runInteraction( "background_analyze_new_stream_ordering_column", - lambda txn: None, - "ANALYZE events(stream_ordering2)", + lambda txn: txn.execute("ANALYZE events(stream_ordering2)"), ) await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py index f851bff604..0ba84b1469 100644 --- a/synapse/storage/databases/main/events_forward_extremities.py +++ b/synapse/storage/databases/main/events_forward_extremities.py @@ -13,7 +13,7 @@ # limitations under the License. 
import logging -from typing import Any, Dict, List +from typing import List, Optional, Tuple, cast from synapse.api.errors import SynapseError from synapse.storage.database import LoggingTransaction @@ -91,12 +91,17 @@ class EventForwardExtremitiesStore( async def get_forward_extremities_for_room( self, room_id: str - ) -> List[Dict[str, Any]]: - """Get list of forward extremities for a room.""" + ) -> List[Tuple[str, int, int, Optional[int]]]: + """ + Get list of forward extremities for a room. + + Returns: + A list of tuples of event_id, state_group, depth, and received_ts. + """ def get_forward_extremities_for_room_txn( txn: LoggingTransaction, - ) -> List[Dict[str, Any]]: + ) -> List[Tuple[str, int, int, Optional[int]]]: sql = """ SELECT event_id, state_group, depth, received_ts FROM event_forward_extremities @@ -106,7 +111,7 @@ class EventForwardExtremitiesStore( """ txn.execute(sql, (room_id,)) - return self.db_pool.cursor_to_dict(txn) + return cast(List[Tuple[str, int, int, Optional[int]]], txn.fetchall()) return await self.db_pool.runInteraction( "get_forward_extremities_for_room", diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index f82140b2e8..aeb3db596c 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -650,7 +650,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): async def get_remote_media_ids( self, before_ts: int, include_quarantined_media: bool - ) -> List[Dict[str, str]]: + ) -> List[Tuple[str, str, str]]: """ Retrieve a list of server name, media ID tuples from the remote media cache. @@ -664,12 +664,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): A list of tuples containing: * The server name of homeserver where the media originates from, * The ID of the media. + * The filesystem ID. 
+ """ + + sql = """ + SELECT media_origin, media_id, filesystem_id + FROM remote_media_cache + WHERE last_access_ts < ? """ - sql = ( - "SELECT media_origin, media_id, filesystem_id" - " FROM remote_media_cache" - " WHERE last_access_ts < ?" - ) if include_quarantined_media is False: # Only include media that has not been quarantined @@ -677,8 +679,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): AND quarantined_by IS NULL """ - return await self.db_pool.execute( - "get_remote_media_ids", self.db_pool.cursor_to_dict, sql, before_ts + return cast( + List[Tuple[str, str, str]], + await self.db_pool.execute("get_remote_media_ids", sql, before_ts), ) async def delete_remote_media(self, media_origin: str, media_id: str) -> None: diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index b0ef7be155..e09ab21593 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -151,6 +151,22 @@ class ThreepidResult: added_at: int +@attr.s(frozen=True, slots=True, auto_attribs=True) +class ThreepidValidationSession: + address: str + """address of the 3pid""" + medium: str + """medium of the 3pid""" + client_secret: str + """a secret provided by the client for this validation session""" + session_id: str + """ID of the validation session""" + last_send_attempt: int + """a number serving to dedupe send attempts for this session""" + validated_at: Optional[int] + """timestamp of when this session was validated if so""" + + class RegistrationWorkerStore(CacheInvalidationWorkerStore): def __init__( self, @@ -1172,7 +1188,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): address: Optional[str] = None, sid: Optional[str] = None, validated: Optional[bool] = True, - ) -> Optional[Dict[str, Any]]: + ) -> Optional[ThreepidValidationSession]: """Gets a session_id and last_send_attempt (if available) for a combination of validation metadata @@ 
-1187,15 +1203,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): perform no filtering Returns: - A dict containing the following: - * address - address of the 3pid - * medium - medium of the 3pid - * client_secret - a secret provided by the client for this validation session - * session_id - ID of the validation session - * send_attempt - a number serving to dedupe send attempts for this session - * validated_at - timestamp of when this session was validated if so - - Otherwise None if a validation session is not found + A ThreepidValidationSession or None if a validation session is not found """ if not client_secret: raise SynapseError( @@ -1214,7 +1222,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): def get_threepid_validation_session_txn( txn: LoggingTransaction, - ) -> Optional[Dict[str, Any]]: + ) -> Optional[ThreepidValidationSession]: sql = """ SELECT address, session_id, medium, client_secret, last_send_attempt, validated_at @@ -1229,11 +1237,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): sql += " LIMIT 1" txn.execute(sql, list(keyvalues.values())) - rows = self.db_pool.cursor_to_dict(txn) - if not rows: + row = txn.fetchone() + if not row: return None - return rows[0] + return ThreepidValidationSession( + address=row[0], + session_id=row[1], + medium=row[2], + client_secret=row[3], + last_send_attempt=row[4], + validated_at=row[5], + ) return await self.db_pool.runInteraction( "get_threepid_validation_session", get_threepid_validation_session_txn diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index a1627dffb7..67e149b586 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -940,7 +940,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): like_clause = "%:" + host rows = await self.db_pool.execute( - "is_host_joined", None, sql, membership, room_id, like_clause 
+ "is_host_joined", sql, membership, room_id, like_clause ) if not rows: @@ -1168,7 +1168,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): AND forgotten = 0; """ - rows = await self.db_pool.execute("is_forgotten_room", None, sql, room_id) + rows = await self.db_pool.execute("is_forgotten_room", sql, room_id) # `count(*)` returns always an integer # If any rows still exist it means someone has not forgotten this room yet diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 1d69c4a5f0..dbde9130c6 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -26,6 +26,7 @@ from typing import ( Set, Tuple, Union, + cast, ) import attr @@ -506,16 +507,18 @@ class SearchStore(SearchBackgroundUpdateStore): # entire table from the database. sql += " ORDER BY rank DESC LIMIT 500" - results = await self.db_pool.execute( - "search_msgs", self.db_pool.cursor_to_dict, sql, *args + # List of tuples of (rank, room_id, event_id). + results = cast( + List[Tuple[Union[int, float], str, str]], + await self.db_pool.execute("search_msgs", sql, *args), ) - results = list(filter(lambda row: row["room_id"] in room_ids, results)) + results = list(filter(lambda row: row[1] in room_ids, results)) # We set redact_behaviour to block here to prevent redacted events being returned in # search results (which is a data leak) events = await self.get_events_as_list( # type: ignore[attr-defined] - [r["event_id"] for r in results], + [r[2] for r in results], redact_behaviour=EventRedactBehaviour.block, ) @@ -527,16 +530,18 @@ class SearchStore(SearchBackgroundUpdateStore): count_sql += " GROUP BY room_id" - count_results = await self.db_pool.execute( - "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args + # List of tuples of (room_id, count). 
+ count_results = cast( + List[Tuple[str, int]], + await self.db_pool.execute("search_rooms_count", count_sql, *count_args), ) - count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) + count = sum(row[1] for row in count_results if row[0] in room_ids) return { "results": [ - {"event": event_map[r["event_id"]], "rank": r["rank"]} + {"event": event_map[r[2]], "rank": r[0]} for r in results - if r["event_id"] in event_map + if r[2] in event_map ], "highlights": highlights, "count": count, @@ -604,7 +609,7 @@ class SearchStore(SearchBackgroundUpdateStore): search_query = search_term sql = """ SELECT ts_rank_cd(vector, websearch_to_tsquery('english', ?)) as rank, - origin_server_ts, stream_ordering, room_id, event_id + room_id, event_id, origin_server_ts, stream_ordering FROM event_search WHERE vector @@ websearch_to_tsquery('english', ?) AND """ @@ -665,16 +670,18 @@ class SearchStore(SearchBackgroundUpdateStore): # mypy expects to append only a `str`, not an `int` args.append(limit) - results = await self.db_pool.execute( - "search_rooms", self.db_pool.cursor_to_dict, sql, *args + # List of tuples of (rank, room_id, event_id, origin_server_ts, stream_ordering). 
+ results = cast( + List[Tuple[Union[int, float], str, str, int, int]], + await self.db_pool.execute("search_rooms", sql, *args), ) - results = list(filter(lambda row: row["room_id"] in room_ids, results)) + results = list(filter(lambda row: row[1] in room_ids, results)) # We set redact_behaviour to block here to prevent redacted events being returned in # search results (which is a data leak) events = await self.get_events_as_list( # type: ignore[attr-defined] - [r["event_id"] for r in results], + [r[2] for r in results], redact_behaviour=EventRedactBehaviour.block, ) @@ -686,22 +693,23 @@ class SearchStore(SearchBackgroundUpdateStore): count_sql += " GROUP BY room_id" - count_results = await self.db_pool.execute( - "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args + # List of tuples of (room_id, count). + count_results = cast( + List[Tuple[str, int]], + await self.db_pool.execute("search_rooms_count", count_sql, *count_args), ) - count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) + count = sum(row[1] for row in count_results if row[0] in room_ids) return { "results": [ { - "event": event_map[r["event_id"]], - "rank": r["rank"], - "pagination_token": "%s,%s" - % (r["origin_server_ts"], r["stream_ordering"]), + "event": event_map[r[2]], + "rank": r[0], + "pagination_token": "%s,%s" % (r[3], r[4]), } for r in results - if r["event_id"] in event_map + if r[2] in event_map ], "highlights": highlights, "count": count, diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 5b2d0ba870..e96c9b0486 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -679,7 +679,7 @@ class StatsStore(StateDeltasStore): order_by: Optional[str] = UserSortOrder.USER_ID.value, direction: Direction = Direction.FORWARDS, search_term: Optional[str] = None, - ) -> Tuple[List[JsonDict], int]: + ) -> Tuple[List[Tuple[str, Optional[str], int, int]], 
int]: """Function to retrieve a paginated list of users and their uploaded local media (size and number). This will return a json list of users and the total number of users matching the filter criteria. @@ -692,14 +692,19 @@ class StatsStore(StateDeltasStore): order_by: the sort order of the returned list direction: sort ascending or descending search_term: a string to filter user names by + Returns: - A list of user dicts and an integer representing the total number of - users that exist given this query + A tuple of: + A list of tuples of user information (the user ID, displayname, + total number of media, total length of media) and + + An integer representing the total number of users that exist + given this query """ def get_users_media_usage_paginate_txn( txn: LoggingTransaction, - ) -> Tuple[List[JsonDict], int]: + ) -> Tuple[List[Tuple[str, Optional[str], int, int]], int]: filters = [] args: list = [] @@ -773,7 +778,7 @@ class StatsStore(StateDeltasStore): args += [limit, start] txn.execute(sql, args) - users = self.db_pool.cursor_to_dict(txn) + users = cast(List[Tuple[str, Optional[str], int, int]], txn.fetchall()) return users, count diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 872df6bda1..2225f8272d 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -1078,7 +1078,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): """ row = await self.db_pool.execute( - "get_current_topological_token", None, sql, room_id, room_id, stream_key + "get_current_topological_token", sql, room_id, room_id, stream_key ) return row[0][0] if row else 0 @@ -1636,7 +1636,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): rows = await self.db_pool.execute( "get_timeline_gaps", - None, sql, room_id, from_token.stream if from_token else 0, diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 
c4a6475060..fecddb4144 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -478,7 +478,10 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): destination: Optional[str] = None, order_by: str = DestinationSortOrder.DESTINATION.value, direction: Direction = Direction.FORWARDS, - ) -> Tuple[List[JsonDict], int]: + ) -> Tuple[ + List[Tuple[str, Optional[int], Optional[int], Optional[int], Optional[int]]], + int, + ]: """Function to retrieve a paginated list of destinations. This will return a json list of destinations and the total number of destinations matching the filter criteria. @@ -490,13 +493,23 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): order_by: the sort order of the returned list direction: sort ascending or descending Returns: - A tuple of a list of mappings from destination to information + A tuple of a list of tuples of destination information: + * destination + * retry_last_ts + * retry_interval + * failure_ts + * last_successful_stream_ordering and a count of total destinations. """ def get_destinations_paginate_txn( txn: LoggingTransaction, - ) -> Tuple[List[JsonDict], int]: + ) -> Tuple[ + List[ + Tuple[str, Optional[int], Optional[int], Optional[int], Optional[int]] + ], + int, + ]: order_by_column = DestinationSortOrder(order_by).value if direction == Direction.BACKWARDS: @@ -523,7 +536,14 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): LIMIT ? OFFSET ? 
""" txn.execute(sql, args + [limit, start]) - destinations = self.db_pool.cursor_to_dict(txn) + destinations = cast( + List[ + Tuple[ + str, Optional[int], Optional[int], Optional[int], Optional[int] + ] + ], + txn.fetchall(), + ) return destinations, count return await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 23eb92c514..a9f5d68b63 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -1145,15 +1145,19 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): raise Exception("Unrecognized database engine") results = cast( - List[UserProfile], - await self.db_pool.execute( - "search_user_dir", self.db_pool.cursor_to_dict, sql, *args - ), + List[Tuple[str, Optional[str], Optional[str]]], + await self.db_pool.execute("search_user_dir", sql, *args), ) limited = len(results) > limit - return {"limited": limited, "results": results[0:limit]} + return { + "limited": limited, + "results": [ + {"user_id": r[0], "display_name": r[1], "avatar_url": r[2]} + for r in results[0:limit] + ], + } def _filter_text_for_index(text: str) -> str: diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index 6ff533a129..0f9c550b27 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -359,7 +359,6 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): if max_group is None: rows = await self.db_pool.execute( "_background_deduplicate_state", - None, "SELECT coalesce(max(id), 0) FROM state_groups", ) max_group = rows[0][0] diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py index 75ae740b43..08214b0013 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py @@ -100,7 +100,6 @@ class 
FederationCatchUpTestCases(FederatingHomeserverTestCase): event_id, stream_ordering = self.get_success( self.hs.get_datastores().main.db_pool.execute( "test:get_destination_rooms", - None, """ SELECT event_id, stream_ordering FROM destination_rooms dr diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 3f5bfa09d4..67ea640902 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -457,8 +457,8 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase): ); """ self.get_success( - self.store.db_pool.execute( - "test_not_null_constraint", lambda _: None, table_sql + self.store.db_pool.runInteraction( + "test_not_null_constraint", lambda txn: txn.execute(table_sql) ) ) @@ -466,8 +466,8 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase): # using SQLite. index_sql = "CREATE INDEX test_index ON test_constraint(a)" self.get_success( - self.store.db_pool.execute( - "test_not_null_constraint", lambda _: None, index_sql + self.store.db_pool.runInteraction( + "test_not_null_constraint", lambda txn: txn.execute(index_sql) ) ) @@ -574,13 +574,13 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase): ); """ self.get_success( - self.store.db_pool.execute( - "test_foreign_key_constraint", lambda _: None, base_sql + self.store.db_pool.runInteraction( + "test_foreign_key_constraint", lambda txn: txn.execute(base_sql) ) ) self.get_success( - self.store.db_pool.execute( - "test_foreign_key_constraint", lambda _: None, table_sql + self.store.db_pool.runInteraction( + "test_foreign_key_constraint", lambda txn: txn.execute(table_sql) ) ) diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 95f99f4130..6afb5403bd 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -120,7 +120,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase): res = self.get_success( 
self.store.db_pool.execute( - "", None, "SELECT full_user_id from profiles ORDER BY full_user_id" + "", "SELECT full_user_id from profiles ORDER BY full_user_id" ) ) self.assertEqual(len(res), len(expected_values)) diff --git a/tests/storage/test_user_filters.py b/tests/storage/test_user_filters.py index d4637d9d1e..2da6a018e8 100644 --- a/tests/storage/test_user_filters.py +++ b/tests/storage/test_user_filters.py @@ -87,7 +87,7 @@ class UserFiltersStoreTestCase(unittest.HomeserverTestCase): res = self.get_success( self.store.db_pool.execute( - "", None, "SELECT full_user_id from user_filters ORDER BY full_user_id" + "", "SELECT full_user_id from user_filters ORDER BY full_user_id" ) ) self.assertEqual(len(res), len(expected_values)) -- cgit 1.5.1 From cfb6d38c47711b8dfaf0125353aec88d16708b97 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 31 Oct 2023 13:13:28 -0400 Subject: Remove remaining usage of cursor_to_dict. (#16564) --- changelog.d/16564.misc | 1 + synapse/handlers/admin.py | 2 +- synapse/handlers/room_list.py | 43 ++++++------ synapse/handlers/room_summary.py | 26 +++---- synapse/rest/admin/media.py | 6 +- synapse/rest/admin/registration_tokens.py | 13 +++- synapse/rest/admin/rooms.py | 11 ++- synapse/rest/admin/users.py | 10 ++- synapse/storage/background_updates.py | 14 ++-- synapse/storage/database.py | 15 ---- synapse/storage/databases/main/__init__.py | 52 +++++++++++--- synapse/storage/databases/main/devices.py | 55 ++++++++++----- synapse/storage/databases/main/media_repository.py | 48 ++++++++++--- synapse/storage/databases/main/registration.py | 42 +++++++---- synapse/storage/databases/main/room.py | 82 ++++++++++++++++++---- tests/handlers/test_register.py | 22 +++--- tests/storage/test_main.py | 4 +- tests/storage/test_room.py | 11 +-- 18 files changed, 300 insertions(+), 157 deletions(-) create mode 100644 changelog.d/16564.misc (limited to 'synapse/storage/databases/main/devices.py') diff --git a/changelog.d/16564.misc 
b/changelog.d/16564.misc new file mode 100644 index 0000000000..93ceaeafc9 --- /dev/null +++ b/changelog.d/16564.misc @@ -0,0 +1 @@ +Improve type hints. diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 2c2baeac67..d06f8e3296 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -283,7 +283,7 @@ class AdminHandler: start, limit, user_id ) for media in media_ids: - writer.write_media_id(media["media_id"], media) + writer.write_media_id(media.media_id, attr.asdict(media)) logger.info( "[%s] Written %d media_ids of %s", diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 36e2db8975..2947e154be 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -33,6 +33,7 @@ from synapse.api.errors import ( RequestSendFailed, SynapseError, ) +from synapse.storage.databases.main.room import LargestRoomStats from synapse.types import JsonDict, JsonMapping, ThirdPartyInstanceID from synapse.util.caches.descriptors import _CacheContext, cached from synapse.util.caches.response_cache import ResponseCache @@ -170,26 +171,24 @@ class RoomListHandler: ignore_non_federatable=from_federation, ) - def build_room_entry(room: JsonDict) -> JsonDict: + def build_room_entry(room: LargestRoomStats) -> JsonDict: entry = { - "room_id": room["room_id"], - "name": room["name"], - "topic": room["topic"], - "canonical_alias": room["canonical_alias"], - "num_joined_members": room["joined_members"], - "avatar_url": room["avatar"], - "world_readable": room["history_visibility"] + "room_id": room.room_id, + "name": room.name, + "topic": room.topic, + "canonical_alias": room.canonical_alias, + "num_joined_members": room.joined_members, + "avatar_url": room.avatar, + "world_readable": room.history_visibility == HistoryVisibility.WORLD_READABLE, - "guest_can_join": room["guest_access"] == "can_join", - "join_rule": room["join_rules"], - "room_type": room["room_type"], + "guest_can_join": room.guest_access == 
"can_join", + "join_rule": room.join_rules, + "room_type": room.room_type, } # Filter out Nones – rather omit the field altogether return {k: v for k, v in entry.items() if v is not None} - results = [build_room_entry(r) for r in results] - response: JsonDict = {} num_results = len(results) if limit is not None: @@ -212,33 +211,33 @@ class RoomListHandler: # If there was a token given then we assume that there # must be previous results. response["prev_batch"] = RoomListNextBatch( - last_joined_members=initial_entry["num_joined_members"], - last_room_id=initial_entry["room_id"], + last_joined_members=initial_entry.joined_members, + last_room_id=initial_entry.room_id, direction_is_forward=False, ).to_token() if more_to_come: response["next_batch"] = RoomListNextBatch( - last_joined_members=final_entry["num_joined_members"], - last_room_id=final_entry["room_id"], + last_joined_members=final_entry.joined_members, + last_room_id=final_entry.room_id, direction_is_forward=True, ).to_token() else: if has_batch_token: response["next_batch"] = RoomListNextBatch( - last_joined_members=final_entry["num_joined_members"], - last_room_id=final_entry["room_id"], + last_joined_members=final_entry.joined_members, + last_room_id=final_entry.room_id, direction_is_forward=True, ).to_token() if more_to_come: response["prev_batch"] = RoomListNextBatch( - last_joined_members=initial_entry["num_joined_members"], - last_room_id=initial_entry["room_id"], + last_joined_members=initial_entry.joined_members, + last_room_id=initial_entry.room_id, direction_is_forward=False, ).to_token() - response["chunk"] = results + response["chunk"] = [build_room_entry(r) for r in results] response["total_room_count_estimate"] = await self.store.count_public_rooms( network_tuple, diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index dd559b4c45..1dfb12e065 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -703,24 +703,24 @@ class 
RoomSummaryHandler: # there should always be an entry assert stats is not None, "unable to retrieve stats for %s" % (room_id,) - entry = { - "room_id": stats["room_id"], - "name": stats["name"], - "topic": stats["topic"], - "canonical_alias": stats["canonical_alias"], - "num_joined_members": stats["joined_members"], - "avatar_url": stats["avatar"], - "join_rule": stats["join_rules"], + entry: JsonDict = { + "room_id": stats.room_id, + "name": stats.name, + "topic": stats.topic, + "canonical_alias": stats.canonical_alias, + "num_joined_members": stats.joined_members, + "avatar_url": stats.avatar, + "join_rule": stats.join_rules, "world_readable": ( - stats["history_visibility"] == HistoryVisibility.WORLD_READABLE + stats.history_visibility == HistoryVisibility.WORLD_READABLE ), - "guest_can_join": stats["guest_access"] == "can_join", - "room_type": stats["room_type"], + "guest_can_join": stats.guest_access == "can_join", + "room_type": stats.room_type, } if self._msc3266_enabled: - entry["im.nheko.summary.version"] = stats["version"] - entry["im.nheko.summary.encryption"] = stats["encryption"] + entry["im.nheko.summary.version"] = stats.version + entry["im.nheko.summary.encryption"] = stats.encryption # Federation requests need to provide additional information so the # requested server is able to filter the response appropriately. 
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index b7637dff0b..8cf5268854 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -17,6 +17,8 @@ import logging from http import HTTPStatus from typing import TYPE_CHECKING, Optional, Tuple +import attr + from synapse.api.constants import Direction from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.server import HttpServer @@ -418,7 +420,7 @@ class UserMediaRestServlet(RestServlet): start, limit, user_id, order_by, direction ) - ret = {"media": media, "total": total} + ret = {"media": [attr.asdict(m) for m in media], "total": total} if (start + limit) < total: ret["next_token"] = start + len(media) @@ -477,7 +479,7 @@ class UserMediaRestServlet(RestServlet): ) deleted_media, total = await self.media_repository.delete_local_media_ids( - [row["media_id"] for row in media] + [m.media_id for m in media] ) return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total} diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py index ffce92d45e..f3e06d3da3 100644 --- a/synapse/rest/admin/registration_tokens.py +++ b/synapse/rest/admin/registration_tokens.py @@ -77,7 +77,18 @@ class ListRegistrationTokensRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) valid = parse_boolean(request, "valid") token_list = await self.store.get_registration_tokens(valid) - return HTTPStatus.OK, {"registration_tokens": token_list} + return HTTPStatus.OK, { + "registration_tokens": [ + { + "token": t[0], + "uses_allowed": t[1], + "pending": t[2], + "completed": t[3], + "expiry_time": t[4], + } + for t in token_list + ] + } class NewRegistrationTokenRestServlet(RestServlet): diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 0659f22a89..23a034522c 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -16,6 +16,8 @@ from http import HTTPStatus 
from typing import TYPE_CHECKING, List, Optional, Tuple, cast from urllib import parse as urlparse +import attr + from synapse.api.constants import Direction, EventTypes, JoinRules, Membership from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.api.filtering import Filter @@ -306,10 +308,13 @@ class RoomRestServlet(RestServlet): raise NotFoundError("Room not found") members = await self.store.get_users_in_room(room_id) - ret["joined_local_devices"] = await self.store.count_devices_by_users(members) - ret["forgotten"] = await self.store.is_locally_forgotten_room(room_id) + result = attr.asdict(ret) + result["joined_local_devices"] = await self.store.count_devices_by_users( + members + ) + result["forgotten"] = await self.store.is_locally_forgotten_room(room_id) - return HTTPStatus.OK, ret + return HTTPStatus.OK, result async def on_DELETE( self, request: SynapseRequest, room_id: str diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 7fe16130e7..73878dd99d 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -18,6 +18,8 @@ import secrets from http import HTTPStatus from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +import attr + from synapse.api.constants import Direction, UserTypes from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.servlet import ( @@ -161,11 +163,13 @@ class UsersRestServletV2(RestServlet): ) # If support for MSC3866 is not enabled, don't show the approval flag. 
+ filter = None if not self._msc3866_enabled: - for user in users: - del user["approved"] - ret = {"users": users, "total": total} + def _filter(a: attr.Attribute) -> bool: + return a.name != "approved" + + ret = {"users": [attr.asdict(u, filter=filter) for u in users], "total": total} if (start + limit) < total: ret["next_token"] = str(start + len(users)) diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 12829d3d7d..7426dbcad6 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -28,6 +28,7 @@ from typing import ( Sequence, Tuple, Type, + cast, ) import attr @@ -488,14 +489,14 @@ class BackgroundUpdater: True if we have finished running all the background updates, otherwise False """ - def get_background_updates_txn(txn: Cursor) -> List[Dict[str, Any]]: + def get_background_updates_txn(txn: Cursor) -> List[Tuple[str, Optional[str]]]: txn.execute( """ SELECT update_name, depends_on FROM background_updates ORDER BY ordering, update_name """ ) - return self.db_pool.cursor_to_dict(txn) + return cast(List[Tuple[str, Optional[str]]], txn.fetchall()) if not self._current_background_update: all_pending_updates = await self.db_pool.runInteraction( @@ -507,14 +508,13 @@ class BackgroundUpdater: return True # find the first update which isn't dependent on another one in the queue. - pending = {update["update_name"] for update in all_pending_updates} - for upd in all_pending_updates: - depends_on = upd["depends_on"] + pending = {update_name for update_name, depends_on in all_pending_updates} + for update_name, depends_on in all_pending_updates: if not depends_on or depends_on not in pending: break logger.info( "Not starting on bg update %s until %s is done", - upd["update_name"], + update_name, depends_on, ) else: @@ -524,7 +524,7 @@ class BackgroundUpdater: "another: dependency cycle?" 
) - self._current_background_update = upd["update_name"] + self._current_background_update = update_name # We have a background update to run, otherwise we would have returned # early. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index a4e7048368..6d54bb0eb2 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -18,7 +18,6 @@ import logging import time import types from collections import defaultdict -from sys import intern from time import monotonic as monotonic_time from typing import ( TYPE_CHECKING, @@ -1042,20 +1041,6 @@ class DatabasePool: self._db_pool.runWithConnection(inner_func, *args, **kwargs) ) - @staticmethod - def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]: - """Converts a SQL cursor into an list of dicts. - - Args: - cursor: The DBAPI cursor which has executed a query. - Returns: - A list of dicts where the key is the column header. - """ - assert cursor.description is not None, "cursor.description was None" - col_headers = [intern(str(column[0])) for column in cursor.description] - results = [dict(zip(col_headers, row)) for row in cursor] - return results - async def execute(self, desc: str, query: str, *args: Any) -> List[Tuple[Any, ...]]: """Runs a single query for a result set. 
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 840d725114..89f4077351 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -17,6 +17,8 @@ import logging from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast +import attr + from synapse.api.constants import Direction from synapse.config.homeserver import HomeServerConfig from synapse.storage._base import make_in_list_sql_clause @@ -28,7 +30,7 @@ from synapse.storage.database import ( from synapse.storage.databases.main.stats import UserSortOrder from synapse.storage.engines import BaseDatabaseEngine from synapse.storage.types import Cursor -from synapse.types import JsonDict, get_domain_from_id +from synapse.types import get_domain_from_id from .account_data import AccountDataStore from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore @@ -82,6 +84,25 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class UserPaginateResponse: + """This is very similar to UserInfo, but not quite the same.""" + + name: str + user_type: Optional[str] + is_guest: bool + admin: bool + deactivated: bool + shadow_banned: bool + displayname: Optional[str] + avatar_url: Optional[str] + creation_ts: Optional[int] + approved: bool + erased: bool + last_seen_ts: int + locked: bool + + class DataStore( EventsBackgroundUpdatesStore, ExperimentalFeaturesStore, @@ -156,7 +177,7 @@ class DataStore( approved: bool = True, not_user_types: Optional[List[str]] = None, locked: bool = False, - ) -> Tuple[List[JsonDict], int]: + ) -> Tuple[List[UserPaginateResponse], int]: """Function to retrieve a paginated list of users from users list. This will return a json list of users and the total number of users matching the filter criteria. 
@@ -182,7 +203,7 @@ class DataStore( def get_users_paginate_txn( txn: LoggingTransaction, - ) -> Tuple[List[JsonDict], int]: + ) -> Tuple[List[UserPaginateResponse], int]: filters = [] args: list = [] @@ -282,13 +303,24 @@ class DataStore( """ args += [limit, start] txn.execute(sql, args) - users = self.db_pool.cursor_to_dict(txn) - - # some of those boolean values are returned as integers when we're on SQLite - columns_to_boolify = ["erased"] - for user in users: - for column in columns_to_boolify: - user[column] = bool(user[column]) + users = [ + UserPaginateResponse( + name=row[0], + user_type=row[1], + is_guest=bool(row[2]), + admin=bool(row[3]), + deactivated=bool(row[4]), + shadow_banned=bool(row[5]), + displayname=row[6], + avatar_url=row[7], + creation_ts=row[8], + approved=bool(row[9]), + erased=bool(row[10]), + last_seen_ts=row[11], + locked=bool(row[12]), + ) + for row in txn + ] return users, count diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 49edbb9e06..b0811a4cf1 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1620,7 +1620,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): # # For each duplicate, we delete all the existing rows and put one back. 
- KEY_COLS = ["stream_id", "destination", "user_id", "device_id"] last_row = progress.get( "last_row", {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""}, @@ -1628,44 +1627,62 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): def _txn(txn: LoggingTransaction) -> int: clause, args = make_tuple_comparison_clause( - [(x, last_row[x]) for x in KEY_COLS] + [ + ("stream_id", last_row["stream_id"]), + ("destination", last_row["destination"]), + ("user_id", last_row["user_id"]), + ("device_id", last_row["device_id"]), + ] ) - sql = """ + sql = f""" SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts FROM device_lists_outbound_pokes - WHERE %s - GROUP BY %s + WHERE {clause} + GROUP BY stream_id, destination, user_id, device_id HAVING count(*) > 1 - ORDER BY %s + ORDER BY stream_id, destination, user_id, device_id LIMIT ? - """ % ( - clause, # WHERE - ",".join(KEY_COLS), # GROUP BY - ",".join(KEY_COLS), # ORDER BY - ) + """ txn.execute(sql, args + [batch_size]) - rows = self.db_pool.cursor_to_dict(txn) + rows = txn.fetchall() - row = None - for row in rows: + stream_id, destination, user_id, device_id = None, None, None, None + for stream_id, destination, user_id, device_id, _ in rows: self.db_pool.simple_delete_txn( txn, "device_lists_outbound_pokes", - {x: row[x] for x in KEY_COLS}, + { + "stream_id": stream_id, + "destination": destination, + "user_id": user_id, + "device_id": device_id, + }, ) - row["sent"] = False self.db_pool.simple_insert_txn( txn, "device_lists_outbound_pokes", - row, + { + "stream_id": stream_id, + "destination": destination, + "user_id": user_id, + "device_id": device_id, + "sent": False, + }, ) - if row: + if rows: self.db_pool.updates._background_update_progress_txn( txn, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, - {"last_row": row}, + { + "last_row": { + "stream_id": stream_id, + "destination": destination, + "user_id": user_id, + "device_id": device_id, + } + }, ) return len(rows) diff --git 
a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index aeb3db596c..c8d7c9fd32 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -26,6 +26,8 @@ from typing import ( cast, ) +import attr + from synapse.api.constants import Direction from synapse.logging.opentracing import trace from synapse.media._base import ThumbnailInfo @@ -45,6 +47,18 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = ( ) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class LocalMedia: + media_id: str + media_type: str + media_length: int + upload_name: str + created_ts: int + last_access_ts: int + quarantined_by: Optional[str] + safe_from_quarantine: bool + + class MediaSortOrder(Enum): """ Enum to define the sorting method used when returning media with @@ -180,7 +194,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): user_id: str, order_by: str = MediaSortOrder.CREATED_TS.value, direction: Direction = Direction.FORWARDS, - ) -> Tuple[List[Dict[str, Any]], int]: + ) -> Tuple[List[LocalMedia], int]: """Get a paginated list of metadata for a local piece of media which an user_id has uploaded @@ -197,7 +211,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): def get_local_media_by_user_paginate_txn( txn: LoggingTransaction, - ) -> Tuple[List[Dict[str, Any]], int]: + ) -> Tuple[List[LocalMedia], int]: # Set ordering order_by_column = MediaSortOrder(order_by).value @@ -217,14 +231,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): sql = """ SELECT - "media_id", - "media_type", - "media_length", - "upload_name", - "created_ts", - "last_access_ts", - "quarantined_by", - "safe_from_quarantine" + media_id, + media_type, + media_length, + upload_name, + created_ts, + last_access_ts, + quarantined_by, + safe_from_quarantine FROM local_media_repository WHERE user_id = ? 
ORDER BY {order_by_column} {order}, media_id ASC @@ -236,7 +250,19 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): args += [limit, start] txn.execute(sql, args) - media = self.db_pool.cursor_to_dict(txn) + media = [ + LocalMedia( + media_id=row[0], + media_type=row[1], + media_length=row[2], + upload_name=row[3], + created_ts=row[4], + last_access_ts=row[5], + quarantined_by=row[6], + safe_from_quarantine=bool(row[7]), + ) + for row in txn + ] return media, count return await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index e09ab21593..933d76e905 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1517,7 +1517,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): async def get_registration_tokens( self, valid: Optional[bool] = None - ) -> List[Dict[str, Any]]: + ) -> List[Tuple[str, Optional[int], int, int, Optional[int]]]: """List all registration tokens. Used by the admin API. Args: @@ -1526,34 +1526,48 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): Default is None: return all tokens regardless of validity. Returns: - A list of dicts, each containing details of a token. 
+ A list of tuples containing: + * The token + * The number of users allowed (or None) + * Whether it is pending + * Whether it has been completed + * An expiry time (or None if no expiry) """ def select_registration_tokens_txn( txn: LoggingTransaction, now: int, valid: Optional[bool] - ) -> List[Dict[str, Any]]: + ) -> List[Tuple[str, Optional[int], int, int, Optional[int]]]: if valid is None: # Return all tokens regardless of validity - txn.execute("SELECT * FROM registration_tokens") + txn.execute( + """ + SELECT token, uses_allowed, pending, completed, expiry_time + FROM registration_tokens + """ + ) elif valid: # Select valid tokens only - sql = ( - "SELECT * FROM registration_tokens WHERE " - "(uses_allowed > pending + completed OR uses_allowed IS NULL) " - "AND (expiry_time > ? OR expiry_time IS NULL)" - ) + sql = """ + SELECT token, uses_allowed, pending, completed, expiry_time + FROM registration_tokens + WHERE (uses_allowed > pending + completed OR uses_allowed IS NULL) + AND (expiry_time > ? OR expiry_time IS NULL) + """ txn.execute(sql, [now]) else: # Select invalid tokens only - sql = ( - "SELECT * FROM registration_tokens WHERE " - "uses_allowed <= pending + completed OR expiry_time <= ?" - ) + sql = """ + SELECT token, uses_allowed, pending, completed, expiry_time + FROM registration_tokens + WHERE uses_allowed <= pending + completed OR expiry_time <= ? 
+ """ txn.execute(sql, [now]) - return self.db_pool.cursor_to_dict(txn) + return cast( + List[Tuple[str, Optional[int], int, int, Optional[int]]], txn.fetchall() + ) return await self.db_pool.runInteraction( "select_registration_tokens", diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 3e8fcf1975..6d4b9891e7 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -78,6 +78,31 @@ class RatelimitOverride: burst_count: int +@attr.s(slots=True, frozen=True, auto_attribs=True) +class LargestRoomStats: + room_id: str + name: Optional[str] + canonical_alias: Optional[str] + joined_members: int + join_rules: Optional[str] + guest_access: Optional[str] + history_visibility: Optional[str] + state_events: int + avatar: Optional[str] + topic: Optional[str] + room_type: Optional[str] + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RoomStats(LargestRoomStats): + joined_local_members: int + version: Optional[str] + creator: Optional[str] + encryption: Optional[str] + federatable: bool + public: bool + + class RoomSortOrder(Enum): """ Enum to define the sorting method used when returning rooms with get_rooms_paginate @@ -204,7 +229,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): allow_none=True, ) - async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]: + async def get_room_with_stats(self, room_id: str) -> Optional[RoomStats]: """Retrieve room with statistics. Args: @@ -215,7 +240,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): def get_room_with_stats_txn( txn: LoggingTransaction, room_id: str - ) -> Optional[Dict[str, Any]]: + ) -> Optional[RoomStats]: sql = """ SELECT room_id, state.name, state.canonical_alias, curr.joined_members, curr.local_users_in_room AS joined_local_members, rooms.room_version AS version, @@ -229,15 +254,28 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): WHERE room_id = ? 
""" txn.execute(sql, [room_id]) - # Catch error if sql returns empty result to return "None" instead of an error - try: - res = self.db_pool.cursor_to_dict(txn)[0] - except IndexError: + row = txn.fetchone() + if not row: return None - - res["federatable"] = bool(res["federatable"]) - res["public"] = bool(res["public"]) - return res + return RoomStats( + room_id=row[0], + name=row[1], + canonical_alias=row[2], + joined_members=row[3], + joined_local_members=row[4], + version=row[5], + creator=row[6], + encryption=row[7], + federatable=bool(row[8]), + public=bool(row[9]), + join_rules=row[10], + guest_access=row[11], + history_visibility=row[12], + state_events=row[13], + avatar=row[14], + topic=row[15], + room_type=row[16], + ) return await self.db_pool.runInteraction( "get_room_with_stats", get_room_with_stats_txn, room_id @@ -368,7 +406,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): bounds: Optional[Tuple[int, str]], forwards: bool, ignore_non_federatable: bool = False, - ) -> List[Dict[str, Any]]: + ) -> List[LargestRoomStats]: """Gets the largest public rooms (where largest is in terms of joined members, as tracked in the statistics table). 
@@ -505,20 +543,34 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): def _get_largest_public_rooms_txn( txn: LoggingTransaction, - ) -> List[Dict[str, Any]]: + ) -> List[LargestRoomStats]: txn.execute(sql, query_args) - results = self.db_pool.cursor_to_dict(txn) + results = [ + LargestRoomStats( + room_id=r[0], + name=r[1], + canonical_alias=r[3], + joined_members=r[4], + join_rules=r[8], + guest_access=r[7], + history_visibility=r[6], + state_events=0, + avatar=r[5], + topic=r[2], + room_type=r[9], + ) + for r in txn + ] if not forwards: results.reverse() return results - ret_val = await self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_largest_public_rooms", _get_largest_public_rooms_txn ) - return ret_val @cached(max_entries=10000) async def is_room_blocked(self, room_id: str) -> Optional[bool]: diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index e9fbf32c7c..032b89d684 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -342,10 +342,10 @@ class RegistrationTestCase(unittest.HomeserverTestCase): # Ensure the room is properly not federated. room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) assert room is not None - self.assertFalse(room["federatable"]) - self.assertFalse(room["public"]) - self.assertEqual(room["join_rules"], "public") - self.assertIsNone(room["guest_access"]) + self.assertFalse(room.federatable) + self.assertFalse(room.public) + self.assertEqual(room.join_rules, "public") + self.assertIsNone(room.guest_access) # The user should be in the room. rooms = self.get_success(self.store.get_rooms_for_user(user_id)) @@ -372,7 +372,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): # Ensure the room is properly a public room. 
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) assert room is not None - self.assertEqual(room["join_rules"], "public") + self.assertEqual(room.join_rules, "public") # Both users should be in the room. rooms = self.get_success(self.store.get_rooms_for_user(inviter)) @@ -411,9 +411,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase): # Ensure the room is properly a private room. room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) assert room is not None - self.assertFalse(room["public"]) - self.assertEqual(room["join_rules"], "invite") - self.assertEqual(room["guest_access"], "can_join") + self.assertFalse(room.public) + self.assertEqual(room.join_rules, "invite") + self.assertEqual(room.guest_access, "can_join") # Both users should be in the room. rooms = self.get_success(self.store.get_rooms_for_user(inviter)) @@ -455,9 +455,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase): # Ensure the room is properly a private room. room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) assert room is not None - self.assertFalse(room["public"]) - self.assertEqual(room["join_rules"], "invite") - self.assertEqual(room["guest_access"], "can_join") + self.assertFalse(room.public) + self.assertEqual(room.join_rules, "invite") + self.assertEqual(room.guest_access, "can_join") # Both users should be in the room. 
rooms = self.get_success(self.store.get_rooms_for_user(inviter)) diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py index b8823d6993..01c0e5e671 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py @@ -39,11 +39,11 @@ class DataStoreTestCase(unittest.HomeserverTestCase): ) self.assertEqual(1, total) - self.assertEqual(self.displayname, users.pop()["displayname"]) + self.assertEqual(self.displayname, users.pop().displayname) users, total = self.get_success( self.store.get_users_paginate(0, 10, name="BC", guests=False) ) self.assertEqual(1, total) - self.assertEqual(self.displayname, users.pop()["displayname"]) + self.assertEqual(self.displayname, users.pop().displayname) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 1e27f2c275..ce34195a25 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -59,14 +59,9 @@ class RoomStoreTestCase(HomeserverTestCase): def test_get_room_with_stats(self) -> None: res = self.get_success(self.store.get_room_with_stats(self.room.to_string())) assert res is not None - self.assertLessEqual( - { - "room_id": self.room.to_string(), - "creator": self.u_creator.to_string(), - "public": True, - }.items(), - res.items(), - ) + self.assertEqual(res.room_id, self.room.to_string()) + self.assertEqual(res.creator, self.u_creator.to_string()) + self.assertTrue(res.public) def test_get_room_with_stats_unknown_room(self) -> None: self.assertIsNone( -- cgit 1.5.1 From 9738b1c4975b293a1bc25ee27b5527724038baa1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 7 Nov 2023 14:00:25 -0500 Subject: Avoid executing no-op queries. (#16583) If simple_{insert,upsert,update}_many_txn is called without any data to modify then return instead of executing the query. This matches the behavior of simple_{select,delete}_many_txn. 
--- changelog.d/16583.misc | 1 + synapse/storage/database.py | 32 ++++++++++++++++++++++--------- synapse/storage/databases/main/devices.py | 2 +- synapse/storage/databases/main/events.py | 12 ++++++------ synapse/storage/databases/main/room.py | 2 +- synapse/storage/databases/main/search.py | 4 ++-- tests/storage/test_base.py | 25 +++++------------------- 7 files changed, 39 insertions(+), 39 deletions(-) create mode 100644 changelog.d/16583.misc (limited to 'synapse/storage/databases/main/devices.py') diff --git a/changelog.d/16583.misc b/changelog.d/16583.misc new file mode 100644 index 0000000000..df5b27b112 --- /dev/null +++ b/changelog.d/16583.misc @@ -0,0 +1 @@ +Avoid executing no-op queries. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index abc7d8a5d2..792f2e7cdf 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -1117,7 +1117,7 @@ class DatabasePool: txn: LoggingTransaction, table: str, keys: Collection[str], - values: Iterable[Iterable[Any]], + values: Collection[Iterable[Any]], ) -> None: """Executes an INSERT query on the named table. @@ -1130,6 +1130,9 @@ class DatabasePool: keys: list of column names values: for each row, a list of values in the same order as `keys` """ + # If there's nothing to insert, then skip executing the query. + if not values: + return if isinstance(txn.database_engine, PostgresEngine): # We use `execute_values` as it can be a lot faster than `execute_batch`, @@ -1455,7 +1458,7 @@ class DatabasePool: key_names: Collection[str], key_values: Collection[Iterable[Any]], value_names: Collection[str], - value_values: Iterable[Iterable[Any]], + value_values: Collection[Iterable[Any]], ) -> None: """ Upsert, many times. @@ -1468,6 +1471,19 @@ class DatabasePool: value_values: A list of each row's value column values. Ignored if value_names is empty. """ + # If there's nothing to upsert, then skip executing the query. 
+ if not key_values: + return + + # No value columns, therefore make a blank list so that the following + # zip() works correctly. + if not value_names: + value_values = [() for x in range(len(key_values))] + elif len(value_values) != len(key_values): + raise ValueError( + f"{len(key_values)} key rows and {len(value_values)} value rows: should be the same number." + ) + if table not in self._unsafe_to_upsert_tables: return self.simple_upsert_many_txn_native_upsert( txn, table, key_names, key_values, value_names, value_values @@ -1502,10 +1518,6 @@ class DatabasePool: value_values: A list of each row's value column values. Ignored if value_names is empty. """ - # No value columns, therefore make a blank list so that the following - # zip() works correctly. - if not value_names: - value_values = [() for x in range(len(key_values))] # Lock the table just once, to prevent it being done once per row. # Note that, according to Postgres' documentation, once obtained, @@ -1543,10 +1555,7 @@ class DatabasePool: allnames.extend(value_names) if not value_names: - # No value columns, therefore make a blank list so that the - # following zip() works correctly. latter = "NOTHING" - value_values = [() for x in range(len(key_values))] else: latter = "UPDATE SET " + ", ".join( k + "=EXCLUDED." + k for k in value_names @@ -1910,6 +1919,7 @@ class DatabasePool: Returns: The results as a list of tuples. """ + # If there's nothing to select, then skip executing the query. if not iterable: return [] @@ -2044,6 +2054,9 @@ class DatabasePool: raise ValueError( f"{len(key_values)} key rows and {len(value_values)} value rows: should be the same number." ) + # If there is nothing to update, then skip executing the query. + if not key_values: + return # List of tuples of (value values, then key values) # (This matches the order needed for the query) @@ -2278,6 +2291,7 @@ class DatabasePool: Returns: Number rows deleted """ + # If there's nothing to delete, then skip executing the query. 
if not values: return 0 diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index b0811a4cf1..04d12a876c 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -703,7 +703,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): key_names=("destination", "user_id"), key_values=[(destination, user_id) for user_id, _ in rows], value_names=("stream_id",), - value_values=((stream_id,) for _, stream_id in rows), + value_values=[(stream_id,) for _, stream_id in rows], ) # Delete all sent outbound pokes diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 647ba182f6..7c34bde3e5 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1476,7 +1476,7 @@ class PersistEventsStore: txn, table="event_json", keys=("event_id", "room_id", "internal_metadata", "json", "format_version"), - values=( + values=[ ( event.event_id, event.room_id, @@ -1485,7 +1485,7 @@ class PersistEventsStore: event.format_version, ) for event, _ in events_and_contexts - ), + ], ) self.db_pool.simple_insert_many_txn( @@ -1508,7 +1508,7 @@ class PersistEventsStore: "state_key", "rejection_reason", ), - values=( + values=[ ( self._instance_name, event.internal_metadata.stream_ordering, @@ -1527,7 +1527,7 @@ class PersistEventsStore: context.rejected, ) for event, context in events_and_contexts - ), + ], ) # If we're persisting an unredacted event we go and ensure @@ -1550,11 +1550,11 @@ class PersistEventsStore: txn, table="state_events", keys=("event_id", "room_id", "type", "state_key"), - values=( + values=[ (event.event_id, event.room_id, event.type, event.state_key) for event, _ in events_and_contexts if event.is_state() - ), + ], ) def _store_rejected_events_txn( diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 6d4b9891e7..afb880532e 100644 
--- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -2268,7 +2268,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): txn, table="partial_state_rooms_servers", keys=("room_id", "server_name"), - values=((room_id, s) for s in servers), + values=[(room_id, s) for s in servers], ) self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,)) self._invalidate_cache_and_stream( diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index dbde9130c6..f4bef4c99b 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -106,7 +106,7 @@ class SearchWorkerStore(SQLBaseStore): txn, table="event_search", keys=("event_id", "room_id", "key", "value"), - values=( + values=[ ( entry.event_id, entry.room_id, @@ -114,7 +114,7 @@ class SearchWorkerStore(SQLBaseStore): _clean_value_for_search(entry.value), ) for entry in entries - ), + ], ) else: diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index b4c490b568..de4fcfe026 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -189,17 +189,9 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) if USE_POSTGRES_FOR_TESTS: - self.mock_execute_values.assert_called_once_with( - self.mock_txn, - "INSERT INTO tablename (col1, col2) VALUES ?", - [], - template=None, - fetch=False, - ) + self.mock_execute_values.assert_not_called() else: - self.mock_txn.executemany.assert_called_once_with( - "INSERT INTO tablename (col1, col2) VALUES(?, ?)", [] - ) + self.mock_txn.executemany.assert_not_called() @defer.inlineCallbacks def test_select_one_1col(self) -> Generator["defer.Deferred[object]", object, None]: @@ -393,7 +385,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) @defer.inlineCallbacks - def test_update_many_no_values( + def test_update_many_no_iterable( self, ) -> Generator["defer.Deferred[object]", object, None]: yield 
defer.ensureDeferred( @@ -408,16 +400,9 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) if USE_POSTGRES_FOR_TESTS: - self.mock_execute_batch.assert_called_once_with( - self.mock_txn, - "UPDATE tablename SET col3 = ? WHERE col1 = ? AND col2 = ?", - [], - ) + self.mock_execute_batch.assert_not_called() else: - self.mock_txn.executemany.assert_called_once_with( - "UPDATE tablename SET col3 = ? WHERE col1 = ? AND col2 = ?", - [], - ) + self.mock_txn.executemany.assert_not_called() @defer.inlineCallbacks def test_delete_one(self) -> Generator["defer.Deferred[object]", object, None]: -- cgit 1.5.1