summary refs log tree commit diff
path: root/synapse/storage/databases/main
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-10-26 15:12:28 -0400
committerGitHub <noreply@github.com>2023-10-26 15:12:28 -0400
commit679c691f6f7c4f7901e6d075a645a8ade20f44d5 (patch)
tree2092e672d80d8cbdbf18756b3eeb84dcc76c12ac /synapse/storage/databases/main
parentAdd a new module API to update user presence state. (#16544) (diff)
downloadsynapse-679c691f6f7c4f7901e6d075a645a8ade20f44d5.tar.xz
Remove more usages of cursor_to_dict. (#16551)
Mostly to improve type safety.
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--synapse/storage/databases/main/censor_events.py2
-rw-r--r--synapse/storage/databases/main/devices.py3
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py1
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py7
-rw-r--r--synapse/storage/databases/main/events_forward_extremities.py15
-rw-r--r--synapse/storage/databases/main/media_repository.py19
-rw-r--r--synapse/storage/databases/main/registration.py43
-rw-r--r--synapse/storage/databases/main/roommember.py4
-rw-r--r--synapse/storage/databases/main/search.py52
-rw-r--r--synapse/storage/databases/main/stats.py15
-rw-r--r--synapse/storage/databases/main/stream.py3
-rw-r--r--synapse/storage/databases/main/transactions.py28
-rw-r--r--synapse/storage/databases/main/user_directory.py14
13 files changed, 130 insertions, 76 deletions
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index 58177ecec1..711fdddd4e 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -93,7 +93,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
         """
 
         rows = await self.db_pool.execute(
-            "_censor_redactions_fetch", None, sql, before_ts, 100
+            "_censor_redactions_fetch", sql, before_ts, 100
         )
 
         updates = []
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 0b75f6763a..49edbb9e06 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -894,7 +894,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
 
         rows = await self.db_pool.execute(
             "get_all_devices_changed",
-            None,
             sql,
             from_key,
             to_key,
@@ -978,7 +977,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
                 WHERE from_user_id = ? AND stream_id > ?
             """
             rows = await self.db_pool.execute(
-                "get_users_whose_signatures_changed", None, sql, user_id, from_key
+                "get_users_whose_signatures_changed", sql, user_id, from_key
             )
             return {user for row in rows for user in db_to_json(row[0])}
         else:
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index f13d776b0d..f70f95eeba 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -155,7 +155,6 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             """
             rows = await self.db_pool.execute(
                 "get_e2e_device_keys_for_federation_query_check",
-                None,
                 sql,
                 now_stream_id,
                 user_id,
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index c5fce1c82b..0061805150 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -1310,12 +1310,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
         # ANALYZE the new column to build stats on it, to encourage PostgreSQL to use the
         # indexes on it.
-        # We need to pass execute a dummy function to handle the txn's result otherwise
-        # it tries to call fetchall() on it and fails because there's no result to fetch.
-        await self.db_pool.execute(
+        await self.db_pool.runInteraction(
             "background_analyze_new_stream_ordering_column",
-            lambda txn: None,
-            "ANALYZE events(stream_ordering2)",
+            lambda txn: txn.execute("ANALYZE events(stream_ordering2)"),
         )
 
         await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py
index f851bff604..0ba84b1469 100644
--- a/synapse/storage/databases/main/events_forward_extremities.py
+++ b/synapse/storage/databases/main/events_forward_extremities.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import Any, Dict, List
+from typing import List, Optional, Tuple, cast
 
 from synapse.api.errors import SynapseError
 from synapse.storage.database import LoggingTransaction
@@ -91,12 +91,17 @@ class EventForwardExtremitiesStore(
 
     async def get_forward_extremities_for_room(
         self, room_id: str
-    ) -> List[Dict[str, Any]]:
-        """Get list of forward extremities for a room."""
+    ) -> List[Tuple[str, int, int, Optional[int]]]:
+        """
+        Get list of forward extremities for a room.
+
+        Returns:
+            A list of tuples of event_id, state_group, depth, and received_ts.
+        """
 
         def get_forward_extremities_for_room_txn(
             txn: LoggingTransaction,
-        ) -> List[Dict[str, Any]]:
+        ) -> List[Tuple[str, int, int, Optional[int]]]:
             sql = """
                 SELECT event_id, state_group, depth, received_ts
                 FROM event_forward_extremities
@@ -106,7 +111,7 @@ class EventForwardExtremitiesStore(
             """
 
             txn.execute(sql, (room_id,))
-            return self.db_pool.cursor_to_dict(txn)
+            return cast(List[Tuple[str, int, int, Optional[int]]], txn.fetchall())
 
         return await self.db_pool.runInteraction(
             "get_forward_extremities_for_room",
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index f82140b2e8..aeb3db596c 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -650,7 +650,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
     async def get_remote_media_ids(
         self, before_ts: int, include_quarantined_media: bool
-    ) -> List[Dict[str, str]]:
+    ) -> List[Tuple[str, str, str]]:
         """
         Retrieve a list of server name, media ID tuples from the remote media cache.
 
@@ -664,12 +664,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             A list of tuples containing:
                 * The server name of homeserver where the media originates from,
                 * The ID of the media.
+                * The filesystem ID.
+        """
+
+        sql = """
+        SELECT media_origin, media_id, filesystem_id
+        FROM remote_media_cache
+        WHERE last_access_ts < ?
         """
-        sql = (
-            "SELECT media_origin, media_id, filesystem_id"
-            " FROM remote_media_cache"
-            " WHERE last_access_ts < ?"
-        )
 
         if include_quarantined_media is False:
             # Only include media that has not been quarantined
@@ -677,8 +679,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             AND quarantined_by IS NULL
             """
 
-        return await self.db_pool.execute(
-            "get_remote_media_ids", self.db_pool.cursor_to_dict, sql, before_ts
+        return cast(
+            List[Tuple[str, str, str]],
+            await self.db_pool.execute("get_remote_media_ids", sql, before_ts),
         )
 
     async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index b0ef7be155..e09ab21593 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -151,6 +151,22 @@ class ThreepidResult:
     added_at: int
 
 
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class ThreepidValidationSession:
+    address: str
+    """address of the 3pid"""
+    medium: str
+    """medium of the 3pid"""
+    client_secret: str
+    """a secret provided by the client for this validation session"""
+    session_id: str
+    """ID of the validation session"""
+    last_send_attempt: int
+    """a number serving to dedupe send attempts for this session"""
+    validated_at: Optional[int]
+    """timestamp of when this session was validated if so"""
+
+
 class RegistrationWorkerStore(CacheInvalidationWorkerStore):
     def __init__(
         self,
@@ -1172,7 +1188,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         address: Optional[str] = None,
         sid: Optional[str] = None,
         validated: Optional[bool] = True,
-    ) -> Optional[Dict[str, Any]]:
+    ) -> Optional[ThreepidValidationSession]:
         """Gets a session_id and last_send_attempt (if available) for a
         combination of validation metadata
 
@@ -1187,15 +1203,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 perform no filtering
 
         Returns:
-            A dict containing the following:
-                * address - address of the 3pid
-                * medium - medium of the 3pid
-                * client_secret - a secret provided by the client for this validation session
-                * session_id - ID of the validation session
-                * send_attempt - a number serving to dedupe send attempts for this session
-                * validated_at - timestamp of when this session was validated if so
-
-                Otherwise None if a validation session is not found
+            A ThreepidValidationSession or None if a validation session is not found
         """
         if not client_secret:
             raise SynapseError(
@@ -1214,7 +1222,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
         def get_threepid_validation_session_txn(
             txn: LoggingTransaction,
-        ) -> Optional[Dict[str, Any]]:
+        ) -> Optional[ThreepidValidationSession]:
             sql = """
                 SELECT address, session_id, medium, client_secret,
                 last_send_attempt, validated_at
@@ -1229,11 +1237,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             sql += " LIMIT 1"
 
             txn.execute(sql, list(keyvalues.values()))
-            rows = self.db_pool.cursor_to_dict(txn)
-            if not rows:
+            row = txn.fetchone()
+            if not row:
                 return None
 
-            return rows[0]
+            return ThreepidValidationSession(
+                address=row[0],
+                session_id=row[1],
+                medium=row[2],
+                client_secret=row[3],
+                last_send_attempt=row[4],
+                validated_at=row[5],
+            )
 
         return await self.db_pool.runInteraction(
             "get_threepid_validation_session", get_threepid_validation_session_txn
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index a1627dffb7..67e149b586 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -940,7 +940,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
         like_clause = "%:" + host
 
         rows = await self.db_pool.execute(
-            "is_host_joined", None, sql, membership, room_id, like_clause
+            "is_host_joined", sql, membership, room_id, like_clause
         )
 
         if not rows:
@@ -1168,7 +1168,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
                 AND forgotten = 0;
         """
 
-        rows = await self.db_pool.execute("is_forgotten_room", None, sql, room_id)
+        rows = await self.db_pool.execute("is_forgotten_room", sql, room_id)
 
         # `count(*)` returns always an integer
         # If any rows still exist it means someone has not forgotten this room yet
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 1d69c4a5f0..dbde9130c6 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -26,6 +26,7 @@ from typing import (
     Set,
     Tuple,
     Union,
+    cast,
 )
 
 import attr
@@ -506,16 +507,18 @@ class SearchStore(SearchBackgroundUpdateStore):
         # entire table from the database.
         sql += " ORDER BY rank DESC LIMIT 500"
 
-        results = await self.db_pool.execute(
-            "search_msgs", self.db_pool.cursor_to_dict, sql, *args
+        # List of tuples of (rank, room_id, event_id).
+        results = cast(
+            List[Tuple[Union[int, float], str, str]],
+            await self.db_pool.execute("search_msgs", sql, *args),
         )
 
-        results = list(filter(lambda row: row["room_id"] in room_ids, results))
+        results = list(filter(lambda row: row[1] in room_ids, results))
 
         # We set redact_behaviour to block here to prevent redacted events being returned in
         # search results (which is a data leak)
         events = await self.get_events_as_list(  # type: ignore[attr-defined]
-            [r["event_id"] for r in results],
+            [r[2] for r in results],
             redact_behaviour=EventRedactBehaviour.block,
         )
 
@@ -527,16 +530,18 @@ class SearchStore(SearchBackgroundUpdateStore):
 
         count_sql += " GROUP BY room_id"
 
-        count_results = await self.db_pool.execute(
-            "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
+        # List of tuples of (room_id, count).
+        count_results = cast(
+            List[Tuple[str, int]],
+            await self.db_pool.execute("search_rooms_count", count_sql, *count_args),
         )
 
-        count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
+        count = sum(row[1] for row in count_results if row[0] in room_ids)
         return {
             "results": [
-                {"event": event_map[r["event_id"]], "rank": r["rank"]}
+                {"event": event_map[r[2]], "rank": r[0]}
                 for r in results
-                if r["event_id"] in event_map
+                if r[2] in event_map
             ],
             "highlights": highlights,
             "count": count,
@@ -604,7 +609,7 @@ class SearchStore(SearchBackgroundUpdateStore):
             search_query = search_term
             sql = """
             SELECT ts_rank_cd(vector, websearch_to_tsquery('english', ?)) as rank,
-            origin_server_ts, stream_ordering, room_id, event_id
+            room_id, event_id, origin_server_ts, stream_ordering
             FROM event_search
             WHERE vector @@ websearch_to_tsquery('english', ?) AND
             """
@@ -665,16 +670,18 @@ class SearchStore(SearchBackgroundUpdateStore):
         # mypy expects to append only a `str`, not an `int`
         args.append(limit)
 
-        results = await self.db_pool.execute(
-            "search_rooms", self.db_pool.cursor_to_dict, sql, *args
+        # List of tuples of (rank, room_id, event_id, origin_server_ts, stream_ordering).
+        results = cast(
+            List[Tuple[Union[int, float], str, str, int, int]],
+            await self.db_pool.execute("search_rooms", sql, *args),
         )
 
-        results = list(filter(lambda row: row["room_id"] in room_ids, results))
+        results = list(filter(lambda row: row[1] in room_ids, results))
 
         # We set redact_behaviour to block here to prevent redacted events being returned in
         # search results (which is a data leak)
         events = await self.get_events_as_list(  # type: ignore[attr-defined]
-            [r["event_id"] for r in results],
+            [r[2] for r in results],
             redact_behaviour=EventRedactBehaviour.block,
         )
 
@@ -686,22 +693,23 @@ class SearchStore(SearchBackgroundUpdateStore):
 
         count_sql += " GROUP BY room_id"
 
-        count_results = await self.db_pool.execute(
-            "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
+        # List of tuples of (room_id, count).
+        count_results = cast(
+            List[Tuple[str, int]],
+            await self.db_pool.execute("search_rooms_count", count_sql, *count_args),
         )
 
-        count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
+        count = sum(row[1] for row in count_results if row[0] in room_ids)
 
         return {
             "results": [
                 {
-                    "event": event_map[r["event_id"]],
-                    "rank": r["rank"],
-                    "pagination_token": "%s,%s"
-                    % (r["origin_server_ts"], r["stream_ordering"]),
+                    "event": event_map[r[2]],
+                    "rank": r[0],
+                    "pagination_token": "%s,%s" % (r[3], r[4]),
                 }
                 for r in results
-                if r["event_id"] in event_map
+                if r[2] in event_map
             ],
             "highlights": highlights,
             "count": count,
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 5b2d0ba870..e96c9b0486 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -679,7 +679,7 @@ class StatsStore(StateDeltasStore):
         order_by: Optional[str] = UserSortOrder.USER_ID.value,
         direction: Direction = Direction.FORWARDS,
         search_term: Optional[str] = None,
-    ) -> Tuple[List[JsonDict], int]:
+    ) -> Tuple[List[Tuple[str, Optional[str], int, int]], int]:
         """Function to retrieve a paginated list of users and their uploaded local media
         (size and number). This will return a json list of users and the
         total number of users matching the filter criteria.
@@ -692,14 +692,19 @@ class StatsStore(StateDeltasStore):
             order_by: the sort order of the returned list
             direction: sort ascending or descending
             search_term: a string to filter user names by
+
         Returns:
-            A list of user dicts and an integer representing the total number of
-            users that exist given this query
+            A tuple of:
+                A list of tuples of user information (the user ID, displayname,
+                total number of media, total length of media) and
+
+                An integer representing the total number of users that exist
+                given this query
         """
 
         def get_users_media_usage_paginate_txn(
             txn: LoggingTransaction,
-        ) -> Tuple[List[JsonDict], int]:
+        ) -> Tuple[List[Tuple[str, Optional[str], int, int]], int]:
             filters = []
             args: list = []
 
@@ -773,7 +778,7 @@ class StatsStore(StateDeltasStore):
 
             args += [limit, start]
             txn.execute(sql, args)
-            users = self.db_pool.cursor_to_dict(txn)
+            users = cast(List[Tuple[str, Optional[str], int, int]], txn.fetchall())
 
             return users, count
 
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 872df6bda1..2225f8272d 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -1078,7 +1078,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         """
 
         row = await self.db_pool.execute(
-            "get_current_topological_token", None, sql, room_id, room_id, stream_key
+            "get_current_topological_token", sql, room_id, room_id, stream_key
         )
         return row[0][0] if row else 0
 
@@ -1636,7 +1636,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         rows = await self.db_pool.execute(
             "get_timeline_gaps",
-            None,
             sql,
             room_id,
             from_token.stream if from_token else 0,
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index c4a6475060..fecddb4144 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -478,7 +478,10 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
         destination: Optional[str] = None,
         order_by: str = DestinationSortOrder.DESTINATION.value,
         direction: Direction = Direction.FORWARDS,
-    ) -> Tuple[List[JsonDict], int]:
+    ) -> Tuple[
+        List[Tuple[str, Optional[int], Optional[int], Optional[int], Optional[int]]],
+        int,
+    ]:
         """Function to retrieve a paginated list of destinations.
         This will return a json list of destinations and the
         total number of destinations matching the filter criteria.
@@ -490,13 +493,23 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
             order_by: the sort order of the returned list
             direction: sort ascending or descending
         Returns:
-            A tuple of a list of mappings from destination to information
+            A tuple of a list of tuples of destination information:
+                * destination
+                * retry_last_ts
+                * retry_interval
+                * failure_ts
+                * last_successful_stream_ordering
             and a count of total destinations.
         """
 
         def get_destinations_paginate_txn(
             txn: LoggingTransaction,
-        ) -> Tuple[List[JsonDict], int]:
+        ) -> Tuple[
+            List[
+                Tuple[str, Optional[int], Optional[int], Optional[int], Optional[int]]
+            ],
+            int,
+        ]:
             order_by_column = DestinationSortOrder(order_by).value
 
             if direction == Direction.BACKWARDS:
@@ -523,7 +536,14 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
                 LIMIT ? OFFSET ?
             """
             txn.execute(sql, args + [limit, start])
-            destinations = self.db_pool.cursor_to_dict(txn)
+            destinations = cast(
+                List[
+                    Tuple[
+                        str, Optional[int], Optional[int], Optional[int], Optional[int]
+                    ]
+                ],
+                txn.fetchall(),
+            )
             return destinations, count
 
         return await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 23eb92c514..a9f5d68b63 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -1145,15 +1145,19 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
             raise Exception("Unrecognized database engine")
 
         results = cast(
-            List[UserProfile],
-            await self.db_pool.execute(
-                "search_user_dir", self.db_pool.cursor_to_dict, sql, *args
-            ),
+            List[Tuple[str, Optional[str], Optional[str]]],
+            await self.db_pool.execute("search_user_dir", sql, *args),
         )
 
         limited = len(results) > limit
 
-        return {"limited": limited, "results": results[0:limit]}
+        return {
+            "limited": limited,
+            "results": [
+                {"user_id": r[0], "display_name": r[1], "avatar_url": r[2]}
+                for r in results[0:limit]
+            ],
+        }
 
 
 def _filter_text_for_index(text: str) -> str: