summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-10-31 13:13:28 -0400
committerGitHub <noreply@github.com>2023-10-31 13:13:28 -0400
commitcfb6d38c47711b8dfaf0125353aec88d16708b97 (patch)
tree5376fba887e841c9574b5ee444719560e5c47135 /synapse/storage/databases
parentMerge branch 'release-v1.96' into develop (diff)
downloadsynapse-cfb6d38c47711b8dfaf0125353aec88d16708b97.tar.xz
Remove remaining usage of cursor_to_dict. (#16564)
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/__init__.py52
-rw-r--r--synapse/storage/databases/main/devices.py55
-rw-r--r--synapse/storage/databases/main/media_repository.py48
-rw-r--r--synapse/storage/databases/main/registration.py42
-rw-r--r--synapse/storage/databases/main/room.py82
5 files changed, 210 insertions, 69 deletions
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 840d725114..89f4077351 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -17,6 +17,8 @@
 import logging
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast
 
+import attr
+
 from synapse.api.constants import Direction
 from synapse.config.homeserver import HomeServerConfig
 from synapse.storage._base import make_in_list_sql_clause
@@ -28,7 +30,7 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.stats import UserSortOrder
 from synapse.storage.engines import BaseDatabaseEngine
 from synapse.storage.types import Cursor
-from synapse.types import JsonDict, get_domain_from_id
+from synapse.types import get_domain_from_id
 
 from .account_data import AccountDataStore
 from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
@@ -82,6 +84,25 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class UserPaginateResponse:
+    """This is very similar to UserInfo, but not quite the same."""
+
+    name: str
+    user_type: Optional[str]
+    is_guest: bool
+    admin: bool
+    deactivated: bool
+    shadow_banned: bool
+    displayname: Optional[str]
+    avatar_url: Optional[str]
+    creation_ts: Optional[int]
+    approved: bool
+    erased: bool
+    last_seen_ts: int
+    locked: bool
+
+
 class DataStore(
     EventsBackgroundUpdatesStore,
     ExperimentalFeaturesStore,
@@ -156,7 +177,7 @@ class DataStore(
         approved: bool = True,
         not_user_types: Optional[List[str]] = None,
         locked: bool = False,
-    ) -> Tuple[List[JsonDict], int]:
+    ) -> Tuple[List[UserPaginateResponse], int]:
         """Function to retrieve a paginated list of users from
         users list. This will return a json list of users and the
         total number of users matching the filter criteria.
@@ -182,7 +203,7 @@ class DataStore(
 
         def get_users_paginate_txn(
             txn: LoggingTransaction,
-        ) -> Tuple[List[JsonDict], int]:
+        ) -> Tuple[List[UserPaginateResponse], int]:
             filters = []
             args: list = []
 
@@ -282,13 +303,24 @@ class DataStore(
             """
             args += [limit, start]
             txn.execute(sql, args)
-            users = self.db_pool.cursor_to_dict(txn)
-
-            # some of those boolean values are returned as integers when we're on SQLite
-            columns_to_boolify = ["erased"]
-            for user in users:
-                for column in columns_to_boolify:
-                    user[column] = bool(user[column])
+            users = [
+                UserPaginateResponse(
+                    name=row[0],
+                    user_type=row[1],
+                    is_guest=bool(row[2]),
+                    admin=bool(row[3]),
+                    deactivated=bool(row[4]),
+                    shadow_banned=bool(row[5]),
+                    displayname=row[6],
+                    avatar_url=row[7],
+                    creation_ts=row[8],
+                    approved=bool(row[9]),
+                    erased=bool(row[10]),
+                    last_seen_ts=row[11],
+                    locked=bool(row[12]),
+                )
+                for row in txn
+            ]
 
             return users, count
 
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 49edbb9e06..b0811a4cf1 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1620,7 +1620,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
         #
         # For each duplicate, we delete all the existing rows and put one back.
 
-        KEY_COLS = ["stream_id", "destination", "user_id", "device_id"]
         last_row = progress.get(
             "last_row",
             {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""},
@@ -1628,44 +1627,62 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
 
         def _txn(txn: LoggingTransaction) -> int:
             clause, args = make_tuple_comparison_clause(
-                [(x, last_row[x]) for x in KEY_COLS]
+                [
+                    ("stream_id", last_row["stream_id"]),
+                    ("destination", last_row["destination"]),
+                    ("user_id", last_row["user_id"]),
+                    ("device_id", last_row["device_id"]),
+                ]
             )
-            sql = """
+            sql = f"""
                 SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts
                 FROM device_lists_outbound_pokes
-                WHERE %s
-                GROUP BY %s
+                WHERE {clause}
+                GROUP BY stream_id, destination, user_id, device_id
                 HAVING count(*) > 1
-                ORDER BY %s
+                ORDER BY stream_id, destination, user_id, device_id
                 LIMIT ?
-                """ % (
-                clause,  # WHERE
-                ",".join(KEY_COLS),  # GROUP BY
-                ",".join(KEY_COLS),  # ORDER BY
-            )
+                """
             txn.execute(sql, args + [batch_size])
-            rows = self.db_pool.cursor_to_dict(txn)
+            rows = txn.fetchall()
 
-            row = None
-            for row in rows:
+            stream_id, destination, user_id, device_id = None, None, None, None
+            for stream_id, destination, user_id, device_id, _ in rows:
                 self.db_pool.simple_delete_txn(
                     txn,
                     "device_lists_outbound_pokes",
-                    {x: row[x] for x in KEY_COLS},
+                    {
+                        "stream_id": stream_id,
+                        "destination": destination,
+                        "user_id": user_id,
+                        "device_id": device_id,
+                    },
                 )
 
-                row["sent"] = False
                 self.db_pool.simple_insert_txn(
                     txn,
                     "device_lists_outbound_pokes",
-                    row,
+                    {
+                        "stream_id": stream_id,
+                        "destination": destination,
+                        "user_id": user_id,
+                        "device_id": device_id,
+                        "sent": False,
+                    },
                 )
 
-            if row:
+            if rows:
                 self.db_pool.updates._background_update_progress_txn(
                     txn,
                     BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES,
-                    {"last_row": row},
+                    {
+                        "last_row": {
+                            "stream_id": stream_id,
+                            "destination": destination,
+                            "user_id": user_id,
+                            "device_id": device_id,
+                        }
+                    },
                 )
 
             return len(rows)
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index aeb3db596c..c8d7c9fd32 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -26,6 +26,8 @@ from typing import (
     cast,
 )
 
+import attr
+
 from synapse.api.constants import Direction
 from synapse.logging.opentracing import trace
 from synapse.media._base import ThumbnailInfo
@@ -45,6 +47,18 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = (
 )
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class LocalMedia:
+    media_id: str
+    media_type: str
+    media_length: int
+    upload_name: str
+    created_ts: int
+    last_access_ts: int
+    quarantined_by: Optional[str]
+    safe_from_quarantine: bool
+
+
 class MediaSortOrder(Enum):
     """
     Enum to define the sorting method used when returning media with
@@ -180,7 +194,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         user_id: str,
         order_by: str = MediaSortOrder.CREATED_TS.value,
         direction: Direction = Direction.FORWARDS,
-    ) -> Tuple[List[Dict[str, Any]], int]:
+    ) -> Tuple[List[LocalMedia], int]:
         """Get a paginated list of metadata for a local piece of media
         which an user_id has uploaded
 
@@ -197,7 +211,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
         def get_local_media_by_user_paginate_txn(
             txn: LoggingTransaction,
-        ) -> Tuple[List[Dict[str, Any]], int]:
+        ) -> Tuple[List[LocalMedia], int]:
             # Set ordering
             order_by_column = MediaSortOrder(order_by).value
 
@@ -217,14 +231,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
             sql = """
                 SELECT
-                    "media_id",
-                    "media_type",
-                    "media_length",
-                    "upload_name",
-                    "created_ts",
-                    "last_access_ts",
-                    "quarantined_by",
-                    "safe_from_quarantine"
+                    media_id,
+                    media_type,
+                    media_length,
+                    upload_name,
+                    created_ts,
+                    last_access_ts,
+                    quarantined_by,
+                    safe_from_quarantine
                 FROM local_media_repository
                 WHERE user_id = ?
                 ORDER BY {order_by_column} {order}, media_id ASC
@@ -236,7 +250,19 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
             args += [limit, start]
             txn.execute(sql, args)
-            media = self.db_pool.cursor_to_dict(txn)
+            media = [
+                LocalMedia(
+                    media_id=row[0],
+                    media_type=row[1],
+                    media_length=row[2],
+                    upload_name=row[3],
+                    created_ts=row[4],
+                    last_access_ts=row[5],
+                    quarantined_by=row[6],
+                    safe_from_quarantine=bool(row[7]),
+                )
+                for row in txn
+            ]
             return media, count
 
         return await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index e09ab21593..933d76e905 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1517,7 +1517,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
     async def get_registration_tokens(
         self, valid: Optional[bool] = None
-    ) -> List[Dict[str, Any]]:
+    ) -> List[Tuple[str, Optional[int], int, int, Optional[int]]]:
         """List all registration tokens. Used by the admin API.
 
         Args:
@@ -1526,34 +1526,48 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
               Default is None: return all tokens regardless of validity.
 
         Returns:
-            A list of dicts, each containing details of a token.
+            A list of tuples containing:
+                * The token
+                * The number of users allowed (or None)
+                * Whether it is pending
+                * Whether it has been completed
+                * An expiry time (or None if no expiry)
         """
 
         def select_registration_tokens_txn(
             txn: LoggingTransaction, now: int, valid: Optional[bool]
-        ) -> List[Dict[str, Any]]:
+        ) -> List[Tuple[str, Optional[int], int, int, Optional[int]]]:
             if valid is None:
                 # Return all tokens regardless of validity
-                txn.execute("SELECT * FROM registration_tokens")
+                txn.execute(
+                    """
+                    SELECT token, uses_allowed, pending, completed, expiry_time
+                    FROM registration_tokens
+                    """
+                )
 
             elif valid:
                 # Select valid tokens only
-                sql = (
-                    "SELECT * FROM registration_tokens WHERE "
-                    "(uses_allowed > pending + completed OR uses_allowed IS NULL) "
-                    "AND (expiry_time > ? OR expiry_time IS NULL)"
-                )
+                sql = """
+                SELECT token, uses_allowed, pending, completed, expiry_time
+                FROM registration_tokens
+                WHERE (uses_allowed > pending + completed OR uses_allowed IS NULL)
+                    AND (expiry_time > ? OR expiry_time IS NULL)
+                """
                 txn.execute(sql, [now])
 
             else:
                 # Select invalid tokens only
-                sql = (
-                    "SELECT * FROM registration_tokens WHERE "
-                    "uses_allowed <= pending + completed OR expiry_time <= ?"
-                )
+                sql = """
+                SELECT token, uses_allowed, pending, completed, expiry_time
+                FROM registration_tokens
+                WHERE uses_allowed <= pending + completed OR expiry_time <= ?
+                """
                 txn.execute(sql, [now])
 
-            return self.db_pool.cursor_to_dict(txn)
+            return cast(
+                List[Tuple[str, Optional[int], int, int, Optional[int]]], txn.fetchall()
+            )
 
         return await self.db_pool.runInteraction(
             "select_registration_tokens",
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 3e8fcf1975..6d4b9891e7 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -78,6 +78,31 @@ class RatelimitOverride:
     burst_count: int
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class LargestRoomStats:
+    room_id: str
+    name: Optional[str]
+    canonical_alias: Optional[str]
+    joined_members: int
+    join_rules: Optional[str]
+    guest_access: Optional[str]
+    history_visibility: Optional[str]
+    state_events: int
+    avatar: Optional[str]
+    topic: Optional[str]
+    room_type: Optional[str]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RoomStats(LargestRoomStats):
+    joined_local_members: int
+    version: Optional[str]
+    creator: Optional[str]
+    encryption: Optional[str]
+    federatable: bool
+    public: bool
+
+
 class RoomSortOrder(Enum):
     """
     Enum to define the sorting method used when returning rooms with get_rooms_paginate
@@ -204,7 +229,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             allow_none=True,
         )
 
-    async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]:
+    async def get_room_with_stats(self, room_id: str) -> Optional[RoomStats]:
         """Retrieve room with statistics.
 
         Args:
@@ -215,7 +240,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
 
         def get_room_with_stats_txn(
             txn: LoggingTransaction, room_id: str
-        ) -> Optional[Dict[str, Any]]:
+        ) -> Optional[RoomStats]:
             sql = """
                 SELECT room_id, state.name, state.canonical_alias, curr.joined_members,
                   curr.local_users_in_room AS joined_local_members, rooms.room_version AS version,
@@ -229,15 +254,28 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
                 WHERE room_id = ?
                 """
             txn.execute(sql, [room_id])
-            # Catch error if sql returns empty result to return "None" instead of an error
-            try:
-                res = self.db_pool.cursor_to_dict(txn)[0]
-            except IndexError:
+            row = txn.fetchone()
+            if not row:
                 return None
-
-            res["federatable"] = bool(res["federatable"])
-            res["public"] = bool(res["public"])
-            return res
+            return RoomStats(
+                room_id=row[0],
+                name=row[1],
+                canonical_alias=row[2],
+                joined_members=row[3],
+                joined_local_members=row[4],
+                version=row[5],
+                creator=row[6],
+                encryption=row[7],
+                federatable=bool(row[8]),
+                public=bool(row[9]),
+                join_rules=row[10],
+                guest_access=row[11],
+                history_visibility=row[12],
+                state_events=row[13],
+                avatar=row[14],
+                topic=row[15],
+                room_type=row[16],
+            )
 
         return await self.db_pool.runInteraction(
             "get_room_with_stats", get_room_with_stats_txn, room_id
@@ -368,7 +406,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         bounds: Optional[Tuple[int, str]],
         forwards: bool,
         ignore_non_federatable: bool = False,
-    ) -> List[Dict[str, Any]]:
+    ) -> List[LargestRoomStats]:
         """Gets the largest public rooms (where largest is in terms of joined
         members, as tracked in the statistics table).
 
@@ -505,20 +543,34 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
 
         def _get_largest_public_rooms_txn(
             txn: LoggingTransaction,
-        ) -> List[Dict[str, Any]]:
+        ) -> List[LargestRoomStats]:
             txn.execute(sql, query_args)
 
-            results = self.db_pool.cursor_to_dict(txn)
+            results = [
+                LargestRoomStats(
+                    room_id=r[0],
+                    name=r[1],
+                    canonical_alias=r[3],
+                    joined_members=r[4],
+                    join_rules=r[8],
+                    guest_access=r[7],
+                    history_visibility=r[6],
+                    state_events=0,
+                    avatar=r[5],
+                    topic=r[2],
+                    room_type=r[9],
+                )
+                for r in txn
+            ]
 
             if not forwards:
                 results.reverse()
 
             return results
 
-        ret_val = await self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_largest_public_rooms", _get_largest_public_rooms_txn
         )
-        return ret_val
 
     @cached(max_entries=10000)
     async def is_room_blocked(self, room_id: str) -> Optional[bool]: