summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-10-11 13:24:56 -0400
committerGitHub <noreply@github.com>2023-10-11 13:24:56 -0400
commita4904dcb04b31ce8ed0deaa2c5c80657780f6618 (patch)
tree179aedc3390ce9cafcd5f3d78a20644ab8d3dd87 /synapse/storage/databases
parentHandle content types with parameters. (#16440) (diff)
downloadsynapse-a4904dcb04b31ce8ed0deaa2c5c80657780f6618.tar.xz
Convert simple_select_many_batch, simple_select_many_txn to tuples. (#16444)
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/deviceinbox.py42
-rw-r--r--synapse/storage/databases/main/devices.py49
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py19
-rw-r--r--synapse/storage/databases/main/event_federation.py107
-rw-r--r--synapse/storage/databases/main/events.py79
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py62
-rw-r--r--synapse/storage/databases/main/events_worker.py36
-rw-r--r--synapse/storage/databases/main/keys.py46
-rw-r--r--synapse/storage/databases/main/presence.py51
-rw-r--r--synapse/storage/databases/main/push_rule.py97
-rw-r--r--synapse/storage/databases/main/relations.py19
-rw-r--r--synapse/storage/databases/main/room.py19
-rw-r--r--synapse/storage/databases/main/roommember.py78
-rw-r--r--synapse/storage/databases/main/state.py62
-rw-r--r--synapse/storage/databases/main/stats.py37
-rw-r--r--synapse/storage/databases/main/transactions.py28
-rw-r--r--synapse/storage/databases/main/ui_auth.py41
-rw-r--r--synapse/storage/databases/main/user_directory.py54
-rw-r--r--synapse/storage/databases/main/user_erasure_store.py19
-rw-r--r--synapse/storage/databases/state/store.py54
20 files changed, 589 insertions, 410 deletions
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 744e98c6d0..1cf649d371 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -344,18 +344,19 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             # Note that this is more efficient than just dropping `device_id` from the query,
             # since device_inbox has an index on `(user_id, device_id, stream_id)`
             if not device_ids_to_query:
-                user_device_dicts = self.db_pool.simple_select_many_txn(
-                    txn,
-                    table="devices",
-                    column="user_id",
-                    iterable=user_ids_to_query,
-                    keyvalues={"hidden": False},
-                    retcols=("device_id",),
+                user_device_dicts = cast(
+                    List[Tuple[str]],
+                    self.db_pool.simple_select_many_txn(
+                        txn,
+                        table="devices",
+                        column="user_id",
+                        iterable=user_ids_to_query,
+                        keyvalues={"hidden": False},
+                        retcols=("device_id",),
+                    ),
                 )
 
-                device_ids_to_query.update(
-                    {row["device_id"] for row in user_device_dicts}
-                )
+                device_ids_to_query.update({row[0] for row in user_device_dicts})
 
             if not device_ids_to_query:
                 # We've ended up with no devices to query.
@@ -845,20 +846,21 @@ class DeviceInboxWorkerStore(SQLBaseStore):
 
                 # We exclude hidden devices (such as cross-signing keys) here as they are
                 # not expected to receive to-device messages.
-                rows = self.db_pool.simple_select_many_txn(
-                    txn,
-                    table="devices",
-                    keyvalues={"user_id": user_id, "hidden": False},
-                    column="device_id",
-                    iterable=devices,
-                    retcols=("device_id",),
+                rows = cast(
+                    List[Tuple[str]],
+                    self.db_pool.simple_select_many_txn(
+                        txn,
+                        table="devices",
+                        keyvalues={"user_id": user_id, "hidden": False},
+                        column="device_id",
+                        iterable=devices,
+                        retcols=("device_id",),
+                    ),
                 )
 
-                for row in rows:
+                for (device_id,) in rows:
                     # Only insert into the local inbox if the device exists on
                     # this server
-                    device_id = row["device_id"]
-
                     with start_active_span("serialise_to_device_message"):
                         msg = messages_by_device[device_id]
                         set_tag(SynapseTags.TO_DEVICE_TYPE, msg["type"])
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 9f3804a504..fc23d18eba 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1052,16 +1052,19 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     async def get_device_list_last_stream_id_for_remotes(
         self, user_ids: Iterable[str]
     ) -> Mapping[str, Optional[str]]:
-        rows = await self.db_pool.simple_select_many_batch(
-            table="device_lists_remote_extremeties",
-            column="user_id",
-            iterable=user_ids,
-            retcols=("user_id", "stream_id"),
-            desc="get_device_list_last_stream_id_for_remotes",
+        rows = cast(
+            List[Tuple[str, str]],
+            await self.db_pool.simple_select_many_batch(
+                table="device_lists_remote_extremeties",
+                column="user_id",
+                iterable=user_ids,
+                retcols=("user_id", "stream_id"),
+                desc="get_device_list_last_stream_id_for_remotes",
+            ),
         )
 
         results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids}
-        results.update({row["user_id"]: row["stream_id"] for row in rows})
+        results.update(rows)
 
         return results
 
@@ -1077,22 +1080,30 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             The IDs of users whose device lists need resync.
         """
         if user_ids:
-            rows = await self.db_pool.simple_select_many_batch(
-                table="device_lists_remote_resync",
-                column="user_id",
-                iterable=user_ids,
-                retcols=("user_id",),
-                desc="get_user_ids_requiring_device_list_resync_with_iterable",
+            row_tuples = cast(
+                List[Tuple[str]],
+                await self.db_pool.simple_select_many_batch(
+                    table="device_lists_remote_resync",
+                    column="user_id",
+                    iterable=user_ids,
+                    retcols=("user_id",),
+                    desc="get_user_ids_requiring_device_list_resync_with_iterable",
+                ),
             )
+
+            return {row[0] for row in row_tuples}
         else:
-            rows = await self.db_pool.simple_select_list(
-                table="device_lists_remote_resync",
-                keyvalues=None,
-                retcols=("user_id",),
-                desc="get_user_ids_requiring_device_list_resync",
+            rows = cast(
+                List[Dict[str, str]],
+                await self.db_pool.simple_select_list(
+                    table="device_lists_remote_resync",
+                    keyvalues=None,
+                    retcols=("user_id",),
+                    desc="get_user_ids_requiring_device_list_resync",
+                ),
             )
 
-        return {row["user_id"] for row in rows}
+            return {row["user_id"] for row in rows}
 
     async def mark_remote_users_device_caches_as_stale(
         self, user_ids: StrCollection
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 749ae54e20..f13d776b0d 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -493,15 +493,18 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             A map from (algorithm, key_id) to json string for key
         """
 
-        rows = await self.db_pool.simple_select_many_batch(
-            table="e2e_one_time_keys_json",
-            column="key_id",
-            iterable=key_ids,
-            retcols=("algorithm", "key_id", "key_json"),
-            keyvalues={"user_id": user_id, "device_id": device_id},
-            desc="add_e2e_one_time_keys_check",
+        rows = cast(
+            List[Tuple[str, str, str]],
+            await self.db_pool.simple_select_many_batch(
+                table="e2e_one_time_keys_json",
+                column="key_id",
+                iterable=key_ids,
+                retcols=("algorithm", "key_id", "key_json"),
+                keyvalues={"user_id": user_id, "device_id": device_id},
+                desc="add_e2e_one_time_keys_check",
+            ),
         )
-        result = {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows}
+        result = {(algorithm, key_id): key_json for algorithm, key_id, key_json in rows}
         log_kv({"message": "Fetched one time keys for user", "one_time_keys": result})
         return result
 
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index afffa54985..4f80ce75cc 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1049,15 +1049,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         Args:
             event_ids: The event IDs to calculate the max depth of.
         """
-        rows = await self.db_pool.simple_select_many_batch(
-            table="events",
-            column="event_id",
-            iterable=event_ids,
-            retcols=(
-                "event_id",
-                "depth",
+        rows = cast(
+            List[Tuple[str, int]],
+            await self.db_pool.simple_select_many_batch(
+                table="events",
+                column="event_id",
+                iterable=event_ids,
+                retcols=(
+                    "event_id",
+                    "depth",
+                ),
+                desc="get_max_depth_of",
             ),
-            desc="get_max_depth_of",
         )
 
         if not rows:
@@ -1065,10 +1068,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         else:
             max_depth_event_id = ""
             current_max_depth = 0
-            for row in rows:
-                if row["depth"] > current_max_depth:
-                    max_depth_event_id = row["event_id"]
-                    current_max_depth = row["depth"]
+            for event_id, depth in rows:
+                if depth > current_max_depth:
+                    max_depth_event_id = event_id
+                    current_max_depth = depth
 
             return max_depth_event_id, current_max_depth
 
@@ -1078,15 +1081,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         Args:
             event_ids: The event IDs to calculate the max depth of.
         """
-        rows = await self.db_pool.simple_select_many_batch(
-            table="events",
-            column="event_id",
-            iterable=event_ids,
-            retcols=(
-                "event_id",
-                "depth",
+        rows = cast(
+            List[Tuple[str, int]],
+            await self.db_pool.simple_select_many_batch(
+                table="events",
+                column="event_id",
+                iterable=event_ids,
+                retcols=(
+                    "event_id",
+                    "depth",
+                ),
+                desc="get_min_depth_of",
             ),
-            desc="get_min_depth_of",
         )
 
         if not rows:
@@ -1094,10 +1100,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         else:
             min_depth_event_id = ""
             current_min_depth = MAX_DEPTH
-            for row in rows:
-                if row["depth"] < current_min_depth:
-                    min_depth_event_id = row["event_id"]
-                    current_min_depth = row["depth"]
+            for event_id, depth in rows:
+                if depth < current_min_depth:
+                    min_depth_event_id = event_id
+                    current_min_depth = depth
 
             return min_depth_event_id, current_min_depth
 
@@ -1553,19 +1559,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             A filtered down list of `event_ids` that have previous failed pull attempts.
         """
 
-        rows = await self.db_pool.simple_select_many_batch(
-            table="event_failed_pull_attempts",
-            column="event_id",
-            iterable=event_ids,
-            keyvalues={},
-            retcols=("event_id",),
-            desc="get_event_ids_with_failed_pull_attempts",
+        rows = cast(
+            List[Tuple[str]],
+            await self.db_pool.simple_select_many_batch(
+                table="event_failed_pull_attempts",
+                column="event_id",
+                iterable=event_ids,
+                keyvalues={},
+                retcols=("event_id",),
+                desc="get_event_ids_with_failed_pull_attempts",
+            ),
         )
-        event_ids_with_failed_pull_attempts: Set[str] = {
-            row["event_id"] for row in rows
-        }
-
-        return event_ids_with_failed_pull_attempts
+        return {row[0] for row in rows}
 
     @trace
     async def get_event_ids_to_not_pull_from_backoff(
@@ -1585,32 +1590,34 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             A dictionary of event_ids that should not be attempted to be pulled and the
             next timestamp at which we may try pulling them again.
         """
-        event_failed_pull_attempts = await self.db_pool.simple_select_many_batch(
-            table="event_failed_pull_attempts",
-            column="event_id",
-            iterable=event_ids,
-            keyvalues={},
-            retcols=(
-                "event_id",
-                "last_attempt_ts",
-                "num_attempts",
+        event_failed_pull_attempts = cast(
+            List[Tuple[str, int, int]],
+            await self.db_pool.simple_select_many_batch(
+                table="event_failed_pull_attempts",
+                column="event_id",
+                iterable=event_ids,
+                keyvalues={},
+                retcols=(
+                    "event_id",
+                    "last_attempt_ts",
+                    "num_attempts",
+                ),
+                desc="get_event_ids_to_not_pull_from_backoff",
             ),
-            desc="get_event_ids_to_not_pull_from_backoff",
         )
 
         current_time = self._clock.time_msec()
 
         event_ids_with_backoff = {}
-        for event_failed_pull_attempt in event_failed_pull_attempts:
-            event_id = event_failed_pull_attempt["event_id"]
+        for event_id, last_attempt_ts, num_attempts in event_failed_pull_attempts:
             # Exponential back-off (up to the upper bound) so we don't try to
             # pull the same event over and over. ex. 2hr, 4hr, 8hr, 16hr, etc.
             backoff_end_time = (
-                event_failed_pull_attempt["last_attempt_ts"]
+                last_attempt_ts
                 + (
                     2
                     ** min(
-                        event_failed_pull_attempt["num_attempts"],
+                        num_attempts,
                         BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS,
                     )
                 )
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index d4dcdb898c..ef6766b5e0 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -27,6 +27,7 @@ from typing import (
     Optional,
     Set,
     Tuple,
+    Union,
     cast,
 )
 
@@ -501,16 +502,19 @@ class PersistEventsStore:
 
         # We ignore legacy rooms that we aren't filling the chain cover index
         # for.
-        rows = self.db_pool.simple_select_many_txn(
-            txn,
-            table="rooms",
-            column="room_id",
-            iterable={event.room_id for event in events if event.is_state()},
-            keyvalues={},
-            retcols=("room_id", "has_auth_chain_index"),
+        rows = cast(
+            List[Tuple[str, Optional[Union[int, bool]]]],
+            self.db_pool.simple_select_many_txn(
+                txn,
+                table="rooms",
+                column="room_id",
+                iterable={event.room_id for event in events if event.is_state()},
+                keyvalues={},
+                retcols=("room_id", "has_auth_chain_index"),
+            ),
         )
         rooms_using_chain_index = {
-            row["room_id"] for row in rows if row["has_auth_chain_index"]
+            room_id for room_id, has_auth_chain_index in rows if has_auth_chain_index
         }
 
         state_events = {
@@ -571,19 +575,18 @@ class PersistEventsStore:
         # We check if there are any events that need to be handled in the rooms
         # we're looking at. These should just be out of band memberships, where
         # we didn't have the auth chain when we first persisted.
-        rows = db_pool.simple_select_many_txn(
-            txn,
-            table="event_auth_chain_to_calculate",
-            keyvalues={},
-            column="room_id",
-            iterable=set(event_to_room_id.values()),
-            retcols=("event_id", "type", "state_key"),
+        auth_chain_to_calc_rows = cast(
+            List[Tuple[str, str, str]],
+            db_pool.simple_select_many_txn(
+                txn,
+                table="event_auth_chain_to_calculate",
+                keyvalues={},
+                column="room_id",
+                iterable=set(event_to_room_id.values()),
+                retcols=("event_id", "type", "state_key"),
+            ),
         )
-        for row in rows:
-            event_id = row["event_id"]
-            event_type = row["type"]
-            state_key = row["state_key"]
-
+        for event_id, event_type, state_key in auth_chain_to_calc_rows:
             # (We could pull out the auth events for all rows at once using
             # simple_select_many, but this case happens rarely and almost always
             # with a single row.)
@@ -753,23 +756,31 @@ class PersistEventsStore:
         # Step 1, fetch all existing links from all the chains we've seen
         # referenced.
         chain_links = _LinkMap()
-        rows = db_pool.simple_select_many_txn(
-            txn,
-            table="event_auth_chain_links",
-            column="origin_chain_id",
-            iterable={chain_id for chain_id, _ in chain_map.values()},
-            keyvalues={},
-            retcols=(
-                "origin_chain_id",
-                "origin_sequence_number",
-                "target_chain_id",
-                "target_sequence_number",
+        auth_chain_rows = cast(
+            List[Tuple[int, int, int, int]],
+            db_pool.simple_select_many_txn(
+                txn,
+                table="event_auth_chain_links",
+                column="origin_chain_id",
+                iterable={chain_id for chain_id, _ in chain_map.values()},
+                keyvalues={},
+                retcols=(
+                    "origin_chain_id",
+                    "origin_sequence_number",
+                    "target_chain_id",
+                    "target_sequence_number",
+                ),
             ),
         )
-        for row in rows:
+        for (
+            origin_chain_id,
+            origin_sequence_number,
+            target_chain_id,
+            target_sequence_number,
+        ) in auth_chain_rows:
             chain_links.add_link(
-                (row["origin_chain_id"], row["origin_sequence_number"]),
-                (row["target_chain_id"], row["target_sequence_number"]),
+                (origin_chain_id, origin_sequence_number),
+                (target_chain_id, target_sequence_number),
                 new=False,
             )
 
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index daef3685b0..c5fce1c82b 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -369,18 +369,20 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)]
             for chunk in chunks:
-                ev_rows = self.db_pool.simple_select_many_txn(
-                    txn,
-                    table="event_json",
-                    column="event_id",
-                    iterable=chunk,
-                    retcols=["event_id", "json"],
-                    keyvalues={},
+                ev_rows = cast(
+                    List[Tuple[str, str]],
+                    self.db_pool.simple_select_many_txn(
+                        txn,
+                        table="event_json",
+                        column="event_id",
+                        iterable=chunk,
+                        retcols=["event_id", "json"],
+                        keyvalues={},
+                    ),
                 )
 
-                for row in ev_rows:
-                    event_id = row["event_id"]
-                    event_json = db_to_json(row["json"])
+                for event_id, json in ev_rows:
+                    event_json = db_to_json(json)
                     try:
                         origin_server_ts = event_json["origin_server_ts"]
                     except (KeyError, AttributeError):
@@ -563,15 +565,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             if deleted:
                 # We now need to invalidate the caches of these rooms
-                rows = self.db_pool.simple_select_many_txn(
-                    txn,
-                    table="events",
-                    column="event_id",
-                    iterable=to_delete,
-                    keyvalues={},
-                    retcols=("room_id",),
+                rows = cast(
+                    List[Tuple[str]],
+                    self.db_pool.simple_select_many_txn(
+                        txn,
+                        table="events",
+                        column="event_id",
+                        iterable=to_delete,
+                        keyvalues={},
+                        retcols=("room_id",),
+                    ),
                 )
-                room_ids = {row["room_id"] for row in rows}
+                room_ids = {row[0] for row in rows}
                 for room_id in room_ids:
                     txn.call_after(
                         self.get_latest_event_ids_in_room.invalidate, (room_id,)  # type: ignore[attr-defined]
@@ -1038,18 +1043,21 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         count = len(rows)
 
         # We also need to fetch the auth events for them.
-        auth_events = self.db_pool.simple_select_many_txn(
-            txn,
-            table="event_auth",
-            column="event_id",
-            iterable=event_to_room_id,
-            keyvalues={},
-            retcols=("event_id", "auth_id"),
+        auth_events = cast(
+            List[Tuple[str, str]],
+            self.db_pool.simple_select_many_txn(
+                txn,
+                table="event_auth",
+                column="event_id",
+                iterable=event_to_room_id,
+                keyvalues={},
+                retcols=("event_id", "auth_id"),
+            ),
         )
 
         event_to_auth_chain: Dict[str, List[str]] = {}
-        for row in auth_events:
-            event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"])
+        for event_id, auth_id in auth_events:
+            event_to_auth_chain.setdefault(event_id, []).append(auth_id)
 
         # Calculate and persist the chain cover index for this set of events.
         #
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index b788d70fc5..8af638d60f 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1584,16 +1584,19 @@ class EventsWorkerStore(SQLBaseStore):
         """Given a list of event ids, check if we have already processed and
         stored them as non outliers.
         """
-        rows = await self.db_pool.simple_select_many_batch(
-            table="events",
-            retcols=("event_id",),
-            column="event_id",
-            iterable=list(event_ids),
-            keyvalues={"outlier": False},
-            desc="have_events_in_timeline",
+        rows = cast(
+            List[Tuple[str]],
+            await self.db_pool.simple_select_many_batch(
+                table="events",
+                retcols=("event_id",),
+                column="event_id",
+                iterable=list(event_ids),
+                keyvalues={"outlier": False},
+                desc="have_events_in_timeline",
+            ),
         )
 
-        return {r["event_id"] for r in rows}
+        return {r[0] for r in rows}
 
     @trace
     @tag_args
@@ -2336,15 +2339,18 @@ class EventsWorkerStore(SQLBaseStore):
             a dict mapping from event id to partial-stateness. We return True for
             any of the events which are unknown (or are outliers).
         """
-        result = await self.db_pool.simple_select_many_batch(
-            table="partial_state_events",
-            column="event_id",
-            iterable=event_ids,
-            retcols=["event_id"],
-            desc="get_partial_state_events",
+        result = cast(
+            List[Tuple[str]],
+            await self.db_pool.simple_select_many_batch(
+                table="partial_state_events",
+                column="event_id",
+                iterable=event_ids,
+                retcols=["event_id"],
+                desc="get_partial_state_events",
+            ),
         )
         # convert the result to a dict, to make @cachedList work
-        partial = {r["event_id"] for r in result}
+        partial = {r[0] for r in result}
         return {e_id: e_id in partial for e_id in event_ids}
 
     @cached()
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 889c578b9c..ea797864b9 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -16,7 +16,7 @@
 import itertools
 import json
 import logging
-from typing import Dict, Iterable, Mapping, Optional, Tuple
+from typing import Dict, Iterable, List, Mapping, Optional, Tuple, Union, cast
 
 from canonicaljson import encode_canonical_json
 from signedjson.key import decode_verify_key_bytes
@@ -205,35 +205,39 @@ class KeyStore(CacheInvalidationWorkerStore):
 
         If we have multiple entries for a given key ID, returns the most recent.
         """
-        rows = await self.db_pool.simple_select_many_batch(
-            table="server_keys_json",
-            column="key_id",
-            iterable=key_ids,
-            keyvalues={"server_name": server_name},
-            retcols=(
-                "key_id",
-                "from_server",
-                "ts_added_ms",
-                "ts_valid_until_ms",
-                "key_json",
+        rows = cast(
+            List[Tuple[str, str, int, int, Union[bytes, memoryview]]],
+            await self.db_pool.simple_select_many_batch(
+                table="server_keys_json",
+                column="key_id",
+                iterable=key_ids,
+                keyvalues={"server_name": server_name},
+                retcols=(
+                    "key_id",
+                    "from_server",
+                    "ts_added_ms",
+                    "ts_valid_until_ms",
+                    "key_json",
+                ),
+                desc="get_server_keys_json_for_remote",
             ),
-            desc="get_server_keys_json_for_remote",
         )
 
         if not rows:
             return {}
 
-        # We sort the rows so that the most recently added entry is picked up.
-        rows.sort(key=lambda r: r["ts_added_ms"])
+        # We sort the rows by ts_added_ms so that the most recently added entry
+        # will stomp over older entries in the dictionary.
+        rows.sort(key=lambda r: r[2])
 
         return {
-            row["key_id"]: FetchKeyResultForRemote(
+            key_id: FetchKeyResultForRemote(
                 # Cast to bytes since postgresql returns a memoryview.
-                key_json=bytes(row["key_json"]),
-                valid_until_ts=row["ts_valid_until_ms"],
-                added_ts=row["ts_added_ms"],
+                key_json=bytes(key_json),
+                valid_until_ts=ts_valid_until_ms,
+                added_ts=ts_added_ms,
             )
-            for row in rows
+            for key_id, from_server, ts_added_ms, ts_valid_until_ms, key_json in rows
         }
 
     async def get_all_server_keys_json_for_remote(
@@ -260,6 +264,8 @@ class KeyStore(CacheInvalidationWorkerStore):
         if not rows:
             return {}
 
+        # We sort the rows by ts_added_ms so that the most recently added entry
+        # will stomp over older entries in the dictionary.
         rows.sort(key=lambda r: r["ts_added_ms"])
 
         return {
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 519f05fb60..3b444d2d07 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -261,27 +261,40 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
     async def get_presence_for_users(
         self, user_ids: Iterable[str]
     ) -> Mapping[str, UserPresenceState]:
-        rows = await self.db_pool.simple_select_many_batch(
-            table="presence_stream",
-            column="user_id",
-            iterable=user_ids,
-            keyvalues={},
-            retcols=(
-                "user_id",
-                "state",
-                "last_active_ts",
-                "last_federation_update_ts",
-                "last_user_sync_ts",
-                "status_msg",
-                "currently_active",
+        # TODO All these columns are nullable, but we don't expect that:
+        #      https://github.com/matrix-org/synapse/issues/16467
+        rows = cast(
+            List[Tuple[str, str, int, int, int, Optional[str], Union[int, bool]]],
+            await self.db_pool.simple_select_many_batch(
+                table="presence_stream",
+                column="user_id",
+                iterable=user_ids,
+                keyvalues={},
+                retcols=(
+                    "user_id",
+                    "state",
+                    "last_active_ts",
+                    "last_federation_update_ts",
+                    "last_user_sync_ts",
+                    "status_msg",
+                    "currently_active",
+                ),
+                desc="get_presence_for_users",
             ),
-            desc="get_presence_for_users",
         )
 
-        for row in rows:
-            row["currently_active"] = bool(row["currently_active"])
-
-        return {row["user_id"]: UserPresenceState(**row) for row in rows}
+        return {
+            user_id: UserPresenceState(
+                user_id=user_id,
+                state=state,
+                last_active_ts=last_active_ts,
+                last_federation_update_ts=last_federation_update_ts,
+                last_user_sync_ts=last_user_sync_ts,
+                status_msg=status_msg,
+                currently_active=bool(currently_active),
+            )
+            for user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, status_msg, currently_active in rows
+        }
 
     async def should_user_receive_full_presence_with_token(
         self,
@@ -386,6 +399,8 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
         limit = 100
         offset = 0
         while True:
+            # TODO All these columns are nullable, but we don't expect that:
+            #      https://github.com/matrix-org/synapse/issues/16467
             rows = cast(
                 List[Tuple[str, str, int, int, int, Optional[str], Union[int, bool]]],
                 await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 923166974c..f5356e7f80 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -62,20 +62,34 @@ logger = logging.getLogger(__name__)
 
 
 def _load_rules(
-    rawrules: List[JsonDict],
+    rawrules: List[Tuple[str, int, str, str]],
     enabled_map: Dict[str, bool],
     experimental_config: ExperimentalConfig,
 ) -> FilteredPushRules:
     """Take the DB rows returned from the DB and convert them into a full
     `FilteredPushRules` object.
+
+    Args:
+        rawrules: List of tuples of:
+            * rule ID
+            * Priority lass
+            * Conditions (as serialized JSON)
+            * Actions (as serialized JSON)
+        enabled_map: A dictionary of rule ID to a boolean of whether the rule is
+            enabled. This might not include all rule IDs from rawrules.
+        experimental_config: The `experimental_features` section of the Synapse
+            config. (Used to check if various features are enabled.)
+
+    Returns:
+        A new FilteredPushRules object.
     """
 
     ruleslist = [
         PushRule.from_db(
-            rule_id=rawrule["rule_id"],
-            priority_class=rawrule["priority_class"],
-            conditions=rawrule["conditions"],
-            actions=rawrule["actions"],
+            rule_id=rawrule[0],
+            priority_class=rawrule[1],
+            conditions=rawrule[2],
+            actions=rawrule[3],
         )
         for rawrule in rawrules
     ]
@@ -183,7 +197,19 @@ class PushRulesWorkerStore(
 
         enabled_map = await self.get_push_rules_enabled_for_user(user_id)
 
-        return _load_rules(rows, enabled_map, self.hs.config.experimental)
+        return _load_rules(
+            [
+                (
+                    row["rule_id"],
+                    row["priority_class"],
+                    row["conditions"],
+                    row["actions"],
+                )
+                for row in rows
+            ],
+            enabled_map,
+            self.hs.config.experimental,
+        )
 
     async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
         results = await self.db_pool.simple_select_list(
@@ -221,21 +247,36 @@ class PushRulesWorkerStore(
         if not user_ids:
             return {}
 
-        raw_rules: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
+        raw_rules: Dict[str, List[Tuple[str, int, str, str]]] = {
+            user_id: [] for user_id in user_ids
+        }
 
-        rows = await self.db_pool.simple_select_many_batch(
-            table="push_rules",
-            column="user_name",
-            iterable=user_ids,
-            retcols=("*",),
-            desc="bulk_get_push_rules",
-            batch_size=1000,
+        rows = cast(
+            List[Tuple[str, str, int, int, str, str]],
+            await self.db_pool.simple_select_many_batch(
+                table="push_rules",
+                column="user_name",
+                iterable=user_ids,
+                retcols=(
+                    "user_name",
+                    "rule_id",
+                    "priority_class",
+                    "priority",
+                    "conditions",
+                    "actions",
+                ),
+                desc="bulk_get_push_rules",
+                batch_size=1000,
+            ),
         )
 
-        rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
+        # Sort by highest priority_class, then highest priority.
+        rows.sort(key=lambda row: (-int(row[2]), -int(row[3])))
 
-        for row in rows:
-            raw_rules.setdefault(row["user_name"], []).append(row)
+        for user_name, rule_id, priority_class, _, conditions, actions in rows:
+            raw_rules.setdefault(user_name, []).append(
+                (rule_id, priority_class, conditions, actions)
+            )
 
         enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
 
@@ -256,17 +297,19 @@ class PushRulesWorkerStore(
 
         results: Dict[str, Dict[str, bool]] = {user_id: {} for user_id in user_ids}
 
-        rows = await self.db_pool.simple_select_many_batch(
-            table="push_rules_enable",
-            column="user_name",
-            iterable=user_ids,
-            retcols=("user_name", "rule_id", "enabled"),
-            desc="bulk_get_push_rules_enabled",
-            batch_size=1000,
+        rows = cast(
+            List[Tuple[str, str, Optional[int]]],
+            await self.db_pool.simple_select_many_batch(
+                table="push_rules_enable",
+                column="user_name",
+                iterable=user_ids,
+                retcols=("user_name", "rule_id", "enabled"),
+                desc="bulk_get_push_rules_enabled",
+                batch_size=1000,
+            ),
         )
-        for row in rows:
-            enabled = bool(row["enabled"])
-            results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled
+        for user_name, rule_id, enabled in rows:
+            results.setdefault(user_name, {})[rule_id] = bool(enabled)
         return results
 
     async def get_all_push_rule_updates(
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 9246b418f5..7f40e2c446 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -349,16 +349,19 @@ class RelationsWorkerStore(SQLBaseStore):
         def get_all_relation_ids_for_event_with_types_txn(
             txn: LoggingTransaction,
         ) -> List[str]:
-            rows = self.db_pool.simple_select_many_txn(
-                txn=txn,
-                table="event_relations",
-                column="relation_type",
-                iterable=relation_types,
-                keyvalues={"relates_to_id": event_id},
-                retcols=["event_id"],
+            rows = cast(
+                List[Tuple[str]],
+                self.db_pool.simple_select_many_txn(
+                    txn=txn,
+                    table="event_relations",
+                    column="relation_type",
+                    iterable=relation_types,
+                    keyvalues={"relates_to_id": event_id},
+                    retcols=["event_id"],
+                ),
             )
 
-            return [row["event_id"] for row in rows]
+            return [row[0] for row in rows]
 
         return await self.db_pool.runInteraction(
             desc="get_all_relation_ids_for_event_with_types",
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 1d4d99932b..9d24d2c347 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1296,14 +1296,17 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         complete.
         """
 
-        rows: List[Dict[str, str]] = await self.db_pool.simple_select_many_batch(
-            table="partial_state_rooms",
-            column="room_id",
-            iterable=room_ids,
-            retcols=("room_id",),
-            desc="is_partial_state_room_batched",
-        )
-        partial_state_rooms = {row_dict["room_id"] for row_dict in rows}
+        rows = cast(
+            List[Tuple[str]],
+            await self.db_pool.simple_select_many_batch(
+                table="partial_state_rooms",
+                column="room_id",
+                iterable=room_ids,
+                retcols=("room_id",),
+                desc="is_partial_state_room_batched",
+            ),
+        )
+        partial_state_rooms = {row[0] for row in rows}
         return {room_id: room_id in partial_state_rooms for room_id in room_ids}
 
     async def get_join_event_id_and_device_lists_stream_id_for_partial_state(
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index bbe08368db..3a87eba430 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -27,6 +27,7 @@ from typing import (
     Set,
     Tuple,
     Union,
+    cast,
 )
 
 import attr
@@ -683,25 +684,28 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
             Map from user_id to set of rooms that is currently in.
         """
 
-        rows = await self.db_pool.simple_select_many_batch(
-            table="current_state_events",
-            column="state_key",
-            iterable=user_ids,
-            retcols=(
-                "state_key",
-                "room_id",
+        rows = cast(
+            List[Tuple[str, str]],
+            await self.db_pool.simple_select_many_batch(
+                table="current_state_events",
+                column="state_key",
+                iterable=user_ids,
+                retcols=(
+                    "state_key",
+                    "room_id",
+                ),
+                keyvalues={
+                    "type": EventTypes.Member,
+                    "membership": Membership.JOIN,
+                },
+                desc="get_rooms_for_users",
             ),
-            keyvalues={
-                "type": EventTypes.Member,
-                "membership": Membership.JOIN,
-            },
-            desc="get_rooms_for_users",
         )
 
         user_rooms: Dict[str, Set[str]] = {user_id: set() for user_id in user_ids}
 
-        for row in rows:
-            user_rooms[row["state_key"]].add(row["room_id"])
+        for state_key, room_id in rows:
+            user_rooms[state_key].add(room_id)
 
         return {key: frozenset(rooms) for key, rooms in user_rooms.items()}
 
@@ -892,17 +896,20 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
             Map from event ID to `user_id`, or None if event is not a join.
         """
 
-        rows = await self.db_pool.simple_select_many_batch(
-            table="room_memberships",
-            column="event_id",
-            iterable=event_ids,
-            retcols=("user_id", "event_id"),
-            keyvalues={"membership": Membership.JOIN},
-            batch_size=1000,
-            desc="_get_user_ids_from_membership_event_ids",
+        rows = cast(
+            List[Tuple[str, str]],
+            await self.db_pool.simple_select_many_batch(
+                table="room_memberships",
+                column="event_id",
+                iterable=event_ids,
+                retcols=("event_id", "user_id"),
+                keyvalues={"membership": Membership.JOIN},
+                batch_size=1000,
+                desc="_get_user_ids_from_membership_event_ids",
+            ),
         )
 
-        return {row["event_id"]: row["user_id"] for row in rows}
+        return dict(rows)
 
     @cached(max_entries=10000)
     async def is_host_joined(self, room_id: str, host: str) -> bool:
@@ -1202,21 +1209,22 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
             membership event, otherwise the value is None.
         """
 
-        rows = await self.db_pool.simple_select_many_batch(
-            table="room_memberships",
-            column="event_id",
-            iterable=member_event_ids,
-            retcols=("user_id", "membership", "event_id"),
-            keyvalues={},
-            batch_size=500,
-            desc="get_membership_from_event_ids",
+        rows = cast(
+            List[Tuple[str, str, str]],
+            await self.db_pool.simple_select_many_batch(
+                table="room_memberships",
+                column="event_id",
+                iterable=member_event_ids,
+                retcols=("user_id", "membership", "event_id"),
+                keyvalues={},
+                batch_size=500,
+                desc="get_membership_from_event_ids",
+            ),
         )
 
         return {
-            row["event_id"]: EventIdMembership(
-                membership=row["membership"], user_id=row["user_id"]
-            )
-            for row in rows
+            event_id: EventIdMembership(membership=membership, user_id=user_id)
+            for user_id, membership, event_id in rows
         }
 
     async def is_local_host_in_room_ignoring_users(
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 5eaaff5b68..598025dd91 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -20,10 +20,12 @@ from typing import (
     Collection,
     Dict,
     Iterable,
+    List,
     Mapping,
     Optional,
     Set,
     Tuple,
+    cast,
 )
 
 import attr
@@ -388,16 +390,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         Raises:
              RuntimeError if the state is unknown at any of the given events
         """
-        rows = await self.db_pool.simple_select_many_batch(
-            table="event_to_state_groups",
-            column="event_id",
-            iterable=event_ids,
-            keyvalues={},
-            retcols=("event_id", "state_group"),
-            desc="_get_state_group_for_events",
+        rows = cast(
+            List[Tuple[str, int]],
+            await self.db_pool.simple_select_many_batch(
+                table="event_to_state_groups",
+                column="event_id",
+                iterable=event_ids,
+                keyvalues={},
+                retcols=("event_id", "state_group"),
+                desc="_get_state_group_for_events",
+            ),
         )
 
-        res = {row["event_id"]: row["state_group"] for row in rows}
+        res = dict(rows)
         for e in event_ids:
             if e not in res:
                 raise RuntimeError("No state group for unknown or outlier event %s" % e)
@@ -415,16 +420,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             The subset of state groups that are referenced.
         """
 
-        rows = await self.db_pool.simple_select_many_batch(
-            table="event_to_state_groups",
-            column="state_group",
-            iterable=state_groups,
-            keyvalues={},
-            retcols=("DISTINCT state_group",),
-            desc="get_referenced_state_groups",
+        rows = cast(
+            List[Tuple[int]],
+            await self.db_pool.simple_select_many_batch(
+                table="event_to_state_groups",
+                column="state_group",
+                iterable=state_groups,
+                keyvalues={},
+                retcols=("DISTINCT state_group",),
+                desc="get_referenced_state_groups",
+            ),
         )
 
-        return {row["state_group"] for row in rows}
+        return {row[0] for row in rows}
 
     async def update_state_for_partial_state_event(
         self,
@@ -624,16 +632,22 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
             # potentially stale, since there may have been a period where the
             # server didn't share a room with the remote user and therefore may
             # have missed any device updates.
-            rows = self.db_pool.simple_select_many_txn(
-                txn,
-                table="current_state_events",
-                column="room_id",
-                iterable=to_delete,
-                keyvalues={"type": EventTypes.Member, "membership": Membership.JOIN},
-                retcols=("state_key",),
+            rows = cast(
+                List[Tuple[str]],
+                self.db_pool.simple_select_many_txn(
+                    txn,
+                    table="current_state_events",
+                    column="room_id",
+                    iterable=to_delete,
+                    keyvalues={
+                        "type": EventTypes.Member,
+                        "membership": Membership.JOIN,
+                    },
+                    retcols=("state_key",),
+                ),
             )
 
-            potentially_left_users = {row["state_key"] for row in rows}
+            potentially_left_users = {row[0] for row in rows}
 
             # Now lets actually delete the rooms from the DB.
             self.db_pool.simple_delete_many_txn(
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 9d403919e4..5b2d0ba870 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -506,25 +506,28 @@ class StatsStore(StateDeltasStore):
         ) -> Tuple[List[str], Dict[str, int], int, List[str], int]:
             pos = self.get_room_max_stream_ordering()  # type: ignore[attr-defined]
 
-            rows = self.db_pool.simple_select_many_txn(
-                txn,
-                table="current_state_events",
-                column="type",
-                iterable=[
-                    EventTypes.Create,
-                    EventTypes.JoinRules,
-                    EventTypes.RoomHistoryVisibility,
-                    EventTypes.RoomEncryption,
-                    EventTypes.Name,
-                    EventTypes.Topic,
-                    EventTypes.RoomAvatar,
-                    EventTypes.CanonicalAlias,
-                ],
-                keyvalues={"room_id": room_id, "state_key": ""},
-                retcols=["event_id"],
+            rows = cast(
+                List[Tuple[str]],
+                self.db_pool.simple_select_many_txn(
+                    txn,
+                    table="current_state_events",
+                    column="type",
+                    iterable=[
+                        EventTypes.Create,
+                        EventTypes.JoinRules,
+                        EventTypes.RoomHistoryVisibility,
+                        EventTypes.RoomEncryption,
+                        EventTypes.Name,
+                        EventTypes.Topic,
+                        EventTypes.RoomAvatar,
+                        EventTypes.CanonicalAlias,
+                    ],
+                    keyvalues={"room_id": room_id, "state_key": ""},
+                    retcols=["event_id"],
+                ),
             )
 
-            event_ids = cast(List[str], [row["event_id"] for row in rows])
+            event_ids = [row[0] for row in rows]
 
             txn.execute(
                 """
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index f35757280d..c4a6475060 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -211,18 +211,28 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
     async def get_destination_retry_timings_batch(
         self, destinations: StrCollection
     ) -> Mapping[str, Optional[DestinationRetryTimings]]:
-        rows = await self.db_pool.simple_select_many_batch(
-            table="destinations",
-            iterable=destinations,
-            column="destination",
-            retcols=("destination", "failure_ts", "retry_last_ts", "retry_interval"),
-            desc="get_destination_retry_timings_batch",
+        rows = cast(
+            List[Tuple[str, Optional[int], Optional[int], Optional[int]]],
+            await self.db_pool.simple_select_many_batch(
+                table="destinations",
+                iterable=destinations,
+                column="destination",
+                retcols=(
+                    "destination",
+                    "failure_ts",
+                    "retry_last_ts",
+                    "retry_interval",
+                ),
+                desc="get_destination_retry_timings_batch",
+            ),
         )
 
         return {
-            row.pop("destination"): DestinationRetryTimings(**row)
-            for row in rows
-            if row["retry_last_ts"] and row["failure_ts"] and row["retry_interval"]
+            destination: DestinationRetryTimings(
+                failure_ts, retry_last_ts, retry_interval
+            )
+            for destination, failure_ts, retry_last_ts, retry_interval in rows
+            if retry_last_ts and failure_ts and retry_interval
         }
 
     async def set_destination_retry_timings(
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index f38bedbbcd..919c66f553 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -337,13 +337,16 @@ class UIAuthWorkerStore(SQLBaseStore):
 
         # If a registration token was used, decrement the pending counter
         # before deleting the session.
-        rows = self.db_pool.simple_select_many_txn(
-            txn,
-            table="ui_auth_sessions_credentials",
-            column="session_id",
-            iterable=session_ids,
-            keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
-            retcols=["result"],
+        rows = cast(
+            List[Tuple[str]],
+            self.db_pool.simple_select_many_txn(
+                txn,
+                table="ui_auth_sessions_credentials",
+                column="session_id",
+                iterable=session_ids,
+                keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
+                retcols=["result"],
+            ),
         )
 
         # Get the tokens used and how much pending needs to be decremented by.
@@ -353,23 +356,25 @@ class UIAuthWorkerStore(SQLBaseStore):
             # registration token stage for that session will be True.
             # If a token was used to authenticate, but registration was
             # never completed, the result will be the token used.
-            token = db_to_json(r["result"])
+            token = db_to_json(r[0])
             if isinstance(token, str):
                 token_counts[token] = token_counts.get(token, 0) + 1
 
         # Update the `pending` counters.
         if len(token_counts) > 0:
-            token_rows = self.db_pool.simple_select_many_txn(
-                txn,
-                table="registration_tokens",
-                column="token",
-                iterable=list(token_counts.keys()),
-                keyvalues={},
-                retcols=["token", "pending"],
+            token_rows = cast(
+                List[Tuple[str, int]],
+                self.db_pool.simple_select_many_txn(
+                    txn,
+                    table="registration_tokens",
+                    column="token",
+                    iterable=list(token_counts.keys()),
+                    keyvalues={},
+                    retcols=["token", "pending"],
+                ),
             )
-            for token_row in token_rows:
-                token = token_row["token"]
-                new_pending = token_row["pending"] - token_counts[token]
+            for token, pending in token_rows:
+                new_pending = pending - token_counts[token]
                 self.db_pool.simple_update_one_txn(
                     txn,
                     table="registration_tokens",
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index f0dc31fee6..23eb92c514 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -410,25 +410,24 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             )
 
             # Next fetch their profiles. Note that not all users have profiles.
-            profile_rows = self.db_pool.simple_select_many_txn(
-                txn,
-                table="profiles",
-                column="full_user_id",
-                iterable=list(users_to_insert),
-                retcols=(
-                    "full_user_id",
-                    "displayname",
-                    "avatar_url",
+            profile_rows = cast(
+                List[Tuple[str, Optional[str], Optional[str]]],
+                self.db_pool.simple_select_many_txn(
+                    txn,
+                    table="profiles",
+                    column="full_user_id",
+                    iterable=list(users_to_insert),
+                    retcols=(
+                        "full_user_id",
+                        "displayname",
+                        "avatar_url",
+                    ),
+                    keyvalues={},
                 ),
-                keyvalues={},
             )
             profiles = {
-                row["full_user_id"]: _UserDirProfile(
-                    row["full_user_id"],
-                    row["displayname"],
-                    row["avatar_url"],
-                )
-                for row in profile_rows
+                full_user_id: _UserDirProfile(full_user_id, displayname, avatar_url)
+                for full_user_id, displayname, avatar_url in profile_rows
             }
 
             profiles_to_insert = [
@@ -517,18 +516,21 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             and not self.get_if_app_services_interested_in_user(user)  # type: ignore[attr-defined]
         ]
 
-        rows = self.db_pool.simple_select_many_txn(
-            txn,
-            table="users",
-            column="name",
-            iterable=users,
-            keyvalues={
-                "deactivated": 0,
-            },
-            retcols=("name", "user_type"),
+        rows = cast(
+            List[Tuple[str, Optional[str]]],
+            self.db_pool.simple_select_many_txn(
+                txn,
+                table="users",
+                column="name",
+                iterable=users,
+                keyvalues={
+                    "deactivated": 0,
+                },
+                retcols=("name", "user_type"),
+            ),
         )
 
-        return [row["name"] for row in rows if row["user_type"] != UserTypes.SUPPORT]
+        return [name for name, user_type in rows if user_type != UserTypes.SUPPORT]
 
     async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool:
         """Check if the room is either world_readable or publically joinable"""
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index 06fcbe5e54..8bd58c6e3d 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Iterable, Mapping
+from typing import Iterable, List, Mapping, Tuple, cast
 
 from synapse.storage.database import LoggingTransaction
 from synapse.storage.databases.main import CacheInvalidationWorkerStore
@@ -50,14 +50,17 @@ class UserErasureWorkerStore(CacheInvalidationWorkerStore):
         Returns:
             for each user, whether the user has requested erasure.
         """
-        rows = await self.db_pool.simple_select_many_batch(
-            table="erased_users",
-            column="user_id",
-            iterable=user_ids,
-            retcols=("user_id",),
-            desc="are_users_erased",
+        rows = cast(
+            List[Tuple[str]],
+            await self.db_pool.simple_select_many_batch(
+                table="erased_users",
+                column="user_id",
+                iterable=user_ids,
+                retcols=("user_id",),
+                desc="are_users_erased",
+            ),
         )
-        erased_users = {row["user_id"] for row in rows}
+        erased_users = {row[0] for row in rows}
 
         return {u: u in erased_users for u in user_ids}
 
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 6984d11352..09d2a8c5b3 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -13,7 +13,17 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    cast,
+)
 
 import attr
 
@@ -730,19 +740,22 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             "[purge] found %i state groups to delete", len(state_groups_to_delete)
         )
 
-        rows = self.db_pool.simple_select_many_txn(
-            txn,
-            table="state_group_edges",
-            column="prev_state_group",
-            iterable=state_groups_to_delete,
-            keyvalues={},
-            retcols=("state_group",),
+        rows = cast(
+            List[Tuple[int]],
+            self.db_pool.simple_select_many_txn(
+                txn,
+                table="state_group_edges",
+                column="prev_state_group",
+                iterable=state_groups_to_delete,
+                keyvalues={},
+                retcols=("state_group",),
+            ),
         )
 
         remaining_state_groups = {
-            row["state_group"]
-            for row in rows
-            if row["state_group"] not in state_groups_to_delete
+            state_group
+            for state_group, in rows
+            if state_group not in state_groups_to_delete
         }
 
         logger.info(
@@ -799,16 +812,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             A mapping from state group to previous state group.
         """
 
-        rows = await self.db_pool.simple_select_many_batch(
-            table="state_group_edges",
-            column="prev_state_group",
-            iterable=state_groups,
-            keyvalues={},
-            retcols=("prev_state_group", "state_group"),
-            desc="get_previous_state_groups",
+        rows = cast(
+            List[Tuple[int, int]],
+            await self.db_pool.simple_select_many_batch(
+                table="state_group_edges",
+                column="prev_state_group",
+                iterable=state_groups,
+                keyvalues={},
+                retcols=("state_group", "prev_state_group"),
+                desc="get_previous_state_groups",
+            ),
         )
 
-        return {row["state_group"]: row["prev_state_group"] for row in rows}
+        return dict(rows)
 
     async def purge_room_state(
         self, room_id: str, state_groups_to_delete: Collection[int]