diff options
author | Patrick Cloke <clokep@users.noreply.github.com> | 2023-10-31 13:13:28 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-31 13:13:28 -0400 |
commit | cfb6d38c47711b8dfaf0125353aec88d16708b97 (patch) | |
tree | 5376fba887e841c9574b5ee444719560e5c47135 /synapse/storage/databases | |
parent | Merge branch 'release-v1.96' into develop (diff) | |
download | synapse-cfb6d38c47711b8dfaf0125353aec88d16708b97.tar.xz |
Remove remaining usage of cursor_to_dict. (#16564)
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r-- | synapse/storage/databases/main/__init__.py | 52 | ||||
-rw-r--r-- | synapse/storage/databases/main/devices.py | 55 | ||||
-rw-r--r-- | synapse/storage/databases/main/media_repository.py | 48 | ||||
-rw-r--r-- | synapse/storage/databases/main/registration.py | 42 | ||||
-rw-r--r-- | synapse/storage/databases/main/room.py | 82 |
5 files changed, 210 insertions, 69 deletions
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 840d725114..89f4077351 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -17,6 +17,8 @@ import logging from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast +import attr + from synapse.api.constants import Direction from synapse.config.homeserver import HomeServerConfig from synapse.storage._base import make_in_list_sql_clause @@ -28,7 +30,7 @@ from synapse.storage.database import ( from synapse.storage.databases.main.stats import UserSortOrder from synapse.storage.engines import BaseDatabaseEngine from synapse.storage.types import Cursor -from synapse.types import JsonDict, get_domain_from_id +from synapse.types import get_domain_from_id from .account_data import AccountDataStore from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore @@ -82,6 +84,25 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class UserPaginateResponse: + """This is very similar to UserInfo, but not quite the same.""" + + name: str + user_type: Optional[str] + is_guest: bool + admin: bool + deactivated: bool + shadow_banned: bool + displayname: Optional[str] + avatar_url: Optional[str] + creation_ts: Optional[int] + approved: bool + erased: bool + last_seen_ts: int + locked: bool + + class DataStore( EventsBackgroundUpdatesStore, ExperimentalFeaturesStore, @@ -156,7 +177,7 @@ class DataStore( approved: bool = True, not_user_types: Optional[List[str]] = None, locked: bool = False, - ) -> Tuple[List[JsonDict], int]: + ) -> Tuple[List[UserPaginateResponse], int]: """Function to retrieve a paginated list of users from users list. This will return a json list of users and the total number of users matching the filter criteria. @@ -182,7 +203,7 @@ class DataStore( def get_users_paginate_txn( txn: LoggingTransaction, - ) -> Tuple[List[JsonDict], int]: + ) -> Tuple[List[UserPaginateResponse], int]: filters = [] args: list = [] @@ -282,13 +303,24 @@ class DataStore( """ args += [limit, start] txn.execute(sql, args) - users = self.db_pool.cursor_to_dict(txn) - - # some of those boolean values are returned as integers when we're on SQLite - columns_to_boolify = ["erased"] - for user in users: - for column in columns_to_boolify: - user[column] = bool(user[column]) + users = [ + UserPaginateResponse( + name=row[0], + user_type=row[1], + is_guest=bool(row[2]), + admin=bool(row[3]), + deactivated=bool(row[4]), + shadow_banned=bool(row[5]), + displayname=row[6], + avatar_url=row[7], + creation_ts=row[8], + approved=bool(row[9]), + erased=bool(row[10]), + last_seen_ts=row[11], + locked=bool(row[12]), + ) + for row in txn + ] return users, count diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 49edbb9e06..b0811a4cf1 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1620,7 +1620,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): # # For each duplicate, we delete all the existing rows and put one back. - KEY_COLS = ["stream_id", "destination", "user_id", "device_id"] last_row = progress.get( "last_row", {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""}, @@ -1628,44 +1627,62 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): def _txn(txn: LoggingTransaction) -> int: clause, args = make_tuple_comparison_clause( - [(x, last_row[x]) for x in KEY_COLS] + [ + ("stream_id", last_row["stream_id"]), + ("destination", last_row["destination"]), + ("user_id", last_row["user_id"]), + ("device_id", last_row["device_id"]), + ] ) - sql = """ + sql = f""" SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts FROM device_lists_outbound_pokes - WHERE %s - GROUP BY %s + WHERE {clause} + GROUP BY stream_id, destination, user_id, device_id HAVING count(*) > 1 - ORDER BY %s + ORDER BY stream_id, destination, user_id, device_id LIMIT ? - """ % ( - clause, # WHERE - ",".join(KEY_COLS), # GROUP BY - ",".join(KEY_COLS), # ORDER BY - ) + """ txn.execute(sql, args + [batch_size]) - rows = self.db_pool.cursor_to_dict(txn) + rows = txn.fetchall() - row = None - for row in rows: + stream_id, destination, user_id, device_id = None, None, None, None + for stream_id, destination, user_id, device_id, _ in rows: self.db_pool.simple_delete_txn( txn, "device_lists_outbound_pokes", - {x: row[x] for x in KEY_COLS}, + { + "stream_id": stream_id, + "destination": destination, + "user_id": user_id, + "device_id": device_id, + }, ) - row["sent"] = False self.db_pool.simple_insert_txn( txn, "device_lists_outbound_pokes", - row, + { + "stream_id": stream_id, + "destination": destination, + "user_id": user_id, + "device_id": device_id, + "sent": False, + }, ) - if row: + if rows: self.db_pool.updates._background_update_progress_txn( txn, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, - {"last_row": row}, + { + "last_row": { + "stream_id": stream_id, + "destination": destination, + "user_id": user_id, + "device_id": device_id, + } + }, ) return len(rows) diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index aeb3db596c..c8d7c9fd32 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -26,6 +26,8 @@ from typing import ( cast, ) +import attr + from synapse.api.constants import Direction from synapse.logging.opentracing import trace from synapse.media._base import ThumbnailInfo @@ -45,6 +47,18 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = ( ) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class LocalMedia: + media_id: str + media_type: str + media_length: int + upload_name: str + created_ts: int + last_access_ts: int + quarantined_by: Optional[str] + safe_from_quarantine: bool + + class MediaSortOrder(Enum): """ Enum to define the sorting method used when returning media with @@ -180,7 +194,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): user_id: str, order_by: str = MediaSortOrder.CREATED_TS.value, direction: Direction = Direction.FORWARDS, - ) -> Tuple[List[Dict[str, Any]], int]: + ) -> Tuple[List[LocalMedia], int]: """Get a paginated list of metadata for a local piece of media which an user_id has uploaded @@ -197,7 +211,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): def get_local_media_by_user_paginate_txn( txn: LoggingTransaction, - ) -> Tuple[List[Dict[str, Any]], int]: + ) -> Tuple[List[LocalMedia], int]: # Set ordering order_by_column = MediaSortOrder(order_by).value @@ -217,14 +231,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): sql = """ SELECT - "media_id", - "media_type", - "media_length", - "upload_name", - "created_ts", - "last_access_ts", - "quarantined_by", - "safe_from_quarantine" + media_id, + media_type, + media_length, + upload_name, + created_ts, + last_access_ts, + quarantined_by, + safe_from_quarantine FROM local_media_repository WHERE user_id = ? ORDER BY {order_by_column} {order}, media_id ASC @@ -236,7 +250,19 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): args += [limit, start] txn.execute(sql, args) - media = self.db_pool.cursor_to_dict(txn) + media = [ + LocalMedia( + media_id=row[0], + media_type=row[1], + media_length=row[2], + upload_name=row[3], + created_ts=row[4], + last_access_ts=row[5], + quarantined_by=row[6], + safe_from_quarantine=bool(row[7]), + ) + for row in txn + ] return media, count return await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index e09ab21593..933d76e905 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1517,7 +1517,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): async def get_registration_tokens( self, valid: Optional[bool] = None - ) -> List[Dict[str, Any]]: + ) -> List[Tuple[str, Optional[int], int, int, Optional[int]]]: """List all registration tokens. Used by the admin API. Args: @@ -1526,34 +1526,48 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): Default is None: return all tokens regardless of validity. Returns: - A list of dicts, each containing details of a token. + A list of tuples containing: + * The token + * The number of users allowed (or None) + * Whether it is pending + * Whether it has been completed + * An expiry time (or None if no expiry) """ def select_registration_tokens_txn( txn: LoggingTransaction, now: int, valid: Optional[bool] - ) -> List[Dict[str, Any]]: + ) -> List[Tuple[str, Optional[int], int, int, Optional[int]]]: if valid is None: # Return all tokens regardless of validity - txn.execute("SELECT * FROM registration_tokens") + txn.execute( + """ + SELECT token, uses_allowed, pending, completed, expiry_time + FROM registration_tokens + """ + ) elif valid: # Select valid tokens only - sql = ( - "SELECT * FROM registration_tokens WHERE " - "(uses_allowed > pending + completed OR uses_allowed IS NULL) " - "AND (expiry_time > ? OR expiry_time IS NULL)" - ) + sql = """ + SELECT token, uses_allowed, pending, completed, expiry_time + FROM registration_tokens + WHERE (uses_allowed > pending + completed OR uses_allowed IS NULL) + AND (expiry_time > ? OR expiry_time IS NULL) + """ txn.execute(sql, [now]) else: # Select invalid tokens only - sql = ( - "SELECT * FROM registration_tokens WHERE " - "uses_allowed <= pending + completed OR expiry_time <= ?" - ) + sql = """ + SELECT token, uses_allowed, pending, completed, expiry_time + FROM registration_tokens + WHERE uses_allowed <= pending + completed OR expiry_time <= ? + """ txn.execute(sql, [now]) - return self.db_pool.cursor_to_dict(txn) + return cast( + List[Tuple[str, Optional[int], int, int, Optional[int]]], txn.fetchall() + ) return await self.db_pool.runInteraction( "select_registration_tokens", diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 3e8fcf1975..6d4b9891e7 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -78,6 +78,31 @@ class RatelimitOverride: burst_count: int +@attr.s(slots=True, frozen=True, auto_attribs=True) +class LargestRoomStats: + room_id: str + name: Optional[str] + canonical_alias: Optional[str] + joined_members: int + join_rules: Optional[str] + guest_access: Optional[str] + history_visibility: Optional[str] + state_events: int + avatar: Optional[str] + topic: Optional[str] + room_type: Optional[str] + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RoomStats(LargestRoomStats): + joined_local_members: int + version: Optional[str] + creator: Optional[str] + encryption: Optional[str] + federatable: bool + public: bool + + class RoomSortOrder(Enum): """ Enum to define the sorting method used when returning rooms with get_rooms_paginate @@ -204,7 +229,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): allow_none=True, ) - async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]: + async def get_room_with_stats(self, room_id: str) -> Optional[RoomStats]: """Retrieve room with statistics. Args: @@ -215,7 +240,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): def get_room_with_stats_txn( txn: LoggingTransaction, room_id: str - ) -> Optional[Dict[str, Any]]: + ) -> Optional[RoomStats]: sql = """ SELECT room_id, state.name, state.canonical_alias, curr.joined_members, curr.local_users_in_room AS joined_local_members, rooms.room_version AS version, @@ -229,15 +254,28 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): WHERE room_id = ? """ txn.execute(sql, [room_id]) - # Catch error if sql returns empty result to return "None" instead of an error - try: - res = self.db_pool.cursor_to_dict(txn)[0] - except IndexError: + row = txn.fetchone() + if not row: return None - - res["federatable"] = bool(res["federatable"]) - res["public"] = bool(res["public"]) - return res + return RoomStats( + room_id=row[0], + name=row[1], + canonical_alias=row[2], + joined_members=row[3], + joined_local_members=row[4], + version=row[5], + creator=row[6], + encryption=row[7], + federatable=bool(row[8]), + public=bool(row[9]), + join_rules=row[10], + guest_access=row[11], + history_visibility=row[12], + state_events=row[13], + avatar=row[14], + topic=row[15], + room_type=row[16], + ) return await self.db_pool.runInteraction( "get_room_with_stats", get_room_with_stats_txn, room_id @@ -368,7 +406,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): bounds: Optional[Tuple[int, str]], forwards: bool, ignore_non_federatable: bool = False, - ) -> List[Dict[str, Any]]: + ) -> List[LargestRoomStats]: """Gets the largest public rooms (where largest is in terms of joined members, as tracked in the statistics table). @@ -505,20 +543,34 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): def _get_largest_public_rooms_txn( txn: LoggingTransaction, - ) -> List[Dict[str, Any]]: + ) -> List[LargestRoomStats]: txn.execute(sql, query_args) - results = self.db_pool.cursor_to_dict(txn) + results = [ + LargestRoomStats( + room_id=r[0], + name=r[1], + canonical_alias=r[3], + joined_members=r[4], + join_rules=r[8], + guest_access=r[7], + history_visibility=r[6], + state_events=0, + avatar=r[5], + topic=r[2], + room_type=r[9], + ) + for r in txn + ] if not forwards: results.reverse() return results - ret_val = await self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_largest_public_rooms", _get_largest_public_rooms_txn ) - return ret_val @cached(max_entries=10000) async def is_room_blocked(self, room_id: str) -> Optional[bool]: |