-rw-r--r--  changelog.d/16611.misc | 1
-rw-r--r--  changelog.d/16612.misc | 1
-rw-r--r--  changelog.d/16613.feature | 1
-rw-r--r--  changelog.d/16616.feature | 1
-rw-r--r--  changelog.d/16618.misc | 1
-rw-r--r--  docs/postgres.md | 2
-rw-r--r--  docs/usage/configuration/config_documentation.md | 6
-rwxr-xr-x  synapse/_scripts/synapse_port_db.py | 3
-rw-r--r--  synapse/handlers/profile.py | 15
-rw-r--r--  synapse/handlers/room.py | 6
-rw-r--r--  synapse/handlers/room_member.py | 3
-rw-r--r--  synapse/handlers/sso.py | 2
-rw-r--r--  synapse/media/media_repository.py | 70
-rw-r--r--  synapse/media/url_previewer.py | 11
-rw-r--r--  synapse/module_api/__init__.py | 3
-rw-r--r--  synapse/rest/admin/rooms.py | 8
-rw-r--r--  synapse/rest/client/directory.py | 2
-rw-r--r--  synapse/rest/media/thumbnail_resource.py | 16
-rw-r--r--  synapse/storage/database.py | 10
-rw-r--r--  synapse/storage/databases/main/account_data.py | 24
-rw-r--r--  synapse/storage/databases/main/cache.py | 75
-rw-r--r--  synapse/storage/databases/main/devices.py | 43
-rw-r--r--  synapse/storage/databases/main/e2e_room_keys.py | 31
-rw-r--r--  synapse/storage/databases/main/end_to_end_keys.py | 30
-rw-r--r--  synapse/storage/databases/main/event_federation.py | 24
-rw-r--r--  synapse/storage/databases/main/events.py | 3
-rw-r--r--  synapse/storage/databases/main/events_bg_updates.py | 15
-rw-r--r--  synapse/storage/databases/main/events_worker.py | 2
-rw-r--r--  synapse/storage/databases/main/keys.py | 17
-rw-r--r--  synapse/storage/databases/main/media_repository.py | 121
-rw-r--r--  synapse/storage/databases/main/presence.py | 9
-rw-r--r--  synapse/storage/databases/main/profile.py | 28
-rw-r--r--  synapse/storage/databases/main/purge_events.py | 31
-rw-r--r--  synapse/storage/databases/main/push_rule.py | 3
-rw-r--r--  synapse/storage/databases/main/receipts.py | 4
-rw-r--r--  synapse/storage/databases/main/registration.py | 107
-rw-r--r--  synapse/storage/databases/main/room.py | 45
-rw-r--r--  synapse/storage/databases/main/roommember.py | 19
-rw-r--r--  synapse/storage/databases/main/stream.py | 30
-rw-r--r--  synapse/storage/databases/main/task_scheduler.py | 48
-rw-r--r--  synapse/storage/databases/main/transactions.py | 20
-rw-r--r--  synapse/storage/databases/main/ui_auth.py | 31
-rw-r--r--  synapse/storage/databases/main/user_directory.py | 27
-rw-r--r--  synapse/storage/util/id_generators.py | 56
-rw-r--r--  tests/handlers/test_stats.py | 4
-rw-r--r--  tests/handlers/test_user_directory.py | 4
-rw-r--r--  tests/media/test_media_storage.py | 2
-rw-r--r--  tests/rest/admin/test_media.py | 16
-rw-r--r--  tests/rest/admin/test_user.py | 2
-rw-r--r--  tests/rest/client/test_account.py | 8
-rw-r--r--  tests/rest/client/test_register.py | 12
-rw-r--r--  tests/rest/media/test_media_retention.py | 20
-rw-r--r--  tests/storage/databases/main/test_cache.py | 117
-rw-r--r--  tests/storage/test_base.py | 4
-rw-r--r--  tests/storage/test_room.py | 13
-rw-r--r--  tests/utils.py | 2
56 files changed, 749 insertions, 460 deletions
diff --git a/changelog.d/16611.misc b/changelog.d/16611.misc
new file mode 100644
index 0000000000..93ceaeafc9
--- /dev/null
+++ b/changelog.d/16611.misc
@@ -0,0 +1 @@
+Improve type hints.
diff --git a/changelog.d/16612.misc b/changelog.d/16612.misc
new file mode 100644
index 0000000000..93ceaeafc9
--- /dev/null
+++ b/changelog.d/16612.misc
@@ -0,0 +1 @@
+Improve type hints.
diff --git a/changelog.d/16613.feature b/changelog.d/16613.feature
new file mode 100644
index 0000000000..419c56fb83
--- /dev/null
+++ b/changelog.d/16613.feature
@@ -0,0 +1 @@
+Improve the performance of some operations in multi-worker deployments.
diff --git a/changelog.d/16616.feature b/changelog.d/16616.feature
new file mode 100644
index 0000000000..419c56fb83
--- /dev/null
+++ b/changelog.d/16616.feature
@@ -0,0 +1 @@
+Improve the performance of some operations in multi-worker deployments.
diff --git a/changelog.d/16618.misc b/changelog.d/16618.misc
new file mode 100644
index 0000000000..c026e6b995
--- /dev/null
+++ b/changelog.d/16618.misc
@@ -0,0 +1 @@
+Use `dbname` instead of the deprecated `database` connection parameter for psycopg2.
diff --git a/docs/postgres.md b/docs/postgres.md
index 02d4b9b162..ad7c6a0738 100644
--- a/docs/postgres.md
+++ b/docs/postgres.md
@@ -66,7 +66,7 @@ database:
   args:
     user: <user>
     password: <pass>
-    database: <db>
+    dbname: <db>
     host: <host>
     cp_min: 5
     cp_max: 10
diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md
index a1ca5fa98c..a673975e04 100644
--- a/docs/usage/configuration/config_documentation.md
+++ b/docs/usage/configuration/config_documentation.md
@@ -1447,7 +1447,7 @@ database:
   args:
     user: synapse_user
     password: secretpassword
-    database: synapse
+    dbname: synapse
     host: localhost
     port: 5432
     cp_min: 5
@@ -1526,7 +1526,7 @@ databases:
     args:
       user: synapse_user
       password: secretpassword
-      database: synapse_main
+      dbname: synapse_main
       host: localhost
       port: 5432
       cp_min: 5
@@ -1539,7 +1539,7 @@ databases:
     args:
       user: synapse_user
       password: secretpassword
-      database: synapse_state
+      dbname: synapse_state
       host: localhost
       port: 5432
       cp_min: 5
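
Note on the rename: psycopg2 accepts `database` only as a deprecated alias for the libpq parameter `dbname`, which is why the examples above now use `dbname`. A minimal sketch of what these `args` amount to (the `cp_min`/`cp_max` entries are consumed by Twisted's connection pool; the connection details here are illustrative):

    import psycopg2

    # The remaining `args` are forwarded to psycopg2.connect() as keyword
    # arguments. `dbname` is the canonical libpq name; `database` still
    # works for now but is deprecated.
    conn = psycopg2.connect(
        dbname="synapse",
        user="synapse_user",
        password="secretpassword",
        host="localhost",
        port=5432,
    )
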
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index 8364ac8bc9..bb86419028 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -348,8 +348,7 @@ class Porter:
                     backward_chunk = 0
                     already_ported = 0
             else:
-                forward_chunk = row["forward_rowid"]
-                backward_chunk = row["backward_rowid"]
+                forward_chunk, backward_chunk = row
 
             if total_to_port is None:
                 already_ported, total_to_port = await self._get_total_count_to_port(
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index c2109036ec..1027fbfd28 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import logging
 import random
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Union
 
 from synapse.api.errors import (
     AuthError,
@@ -23,6 +23,7 @@ from synapse.api.errors import (
     StoreError,
     SynapseError,
 )
+from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia
 from synapse.types import JsonDict, Requester, UserID, create_requester
 from synapse.util.caches.descriptors import cached
 from synapse.util.stringutils import parse_and_validate_mxc_uri
@@ -306,7 +307,9 @@ class ProfileHandler:
             server_name = host
 
         if self._is_mine_server_name(server_name):
-            media_info = await self.store.get_local_media(media_id)
+            media_info: Optional[
+                Union[LocalMedia, RemoteMedia]
+            ] = await self.store.get_local_media(media_id)
         else:
             media_info = await self.store.get_cached_remote_media(server_name, media_id)
 
@@ -322,12 +325,12 @@ class ProfileHandler:
 
         if self.max_avatar_size:
             # Ensure avatar does not exceed max allowed avatar size
-            if media_info["media_length"] > self.max_avatar_size:
+            if media_info.media_length > self.max_avatar_size:
                 logger.warning(
                     "Forbidding avatar change to %s: %d bytes is above the allowed size "
                     "limit",
                     mxc,
-                    media_info["media_length"],
+                    media_info.media_length,
                 )
                 return False
 
@@ -335,12 +338,12 @@ class ProfileHandler:
             # Ensure the avatar's file type is allowed
             if (
                 self.allowed_avatar_mimetypes
-                and media_info["media_type"] not in self.allowed_avatar_mimetypes
+                and media_info.media_type not in self.allowed_avatar_mimetypes
             ):
                 logger.warning(
                     "Forbidding avatar change to %s: mimetype %s not allowed",
                     mxc,
-                    media_info["media_type"],
+                    media_info.media_type,
                 )
                 return False
 
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 6d680b0795..afd8138caf 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -269,7 +269,7 @@ class RoomCreationHandler:
         self,
         requester: Requester,
         old_room_id: str,
-        old_room: Dict[str, Any],
+        old_room: Tuple[bool, str, bool],
         new_room_id: str,
         new_version: RoomVersion,
         tombstone_event: EventBase,
@@ -279,7 +279,7 @@ class RoomCreationHandler:
         Args:
             requester: the user requesting the upgrade
             old_room_id: the id of the room to be replaced
-            old_room: a dict containing room information for the room to be replaced,
+            old_room: a tuple containing room information for the room to be replaced,
                 as returned by `RoomWorkerStore.get_room`.
             new_room_id: the id of the replacement room
             new_version: the version to upgrade the room to
@@ -299,7 +299,7 @@ class RoomCreationHandler:
         await self.store.store_room(
             room_id=new_room_id,
             room_creator_user_id=user_id,
-            is_public=old_room["is_public"],
+            is_public=old_room[0],
             room_version=new_version,
         )
 
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 918eb203e2..eddc2af9ba 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1260,7 +1260,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         # Add new room to the room directory if the old room was there
         # Remove old room from the room directory
         old_room = await self.store.get_room(old_room_id)
-        if old_room is not None and old_room["is_public"]:
+        # If the old room exists and is public.
+        if old_room is not None and old_room[0]:
             await self.store.set_room_is_public(old_room_id, False)
             await self.store.set_room_is_public(room_id, True)
 
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 62f2454f5d..389dc5298a 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -806,7 +806,7 @@ class SsoHandler:
                 media_id = profile["avatar_url"].split("/")[-1]
                 if self._is_mine_server_name(server_name):
                     media = await self._media_repo.store.get_local_media(media_id)
-                    if media is not None and upload_name == media["upload_name"]:
+                    if media is not None and upload_name == media.upload_name:
                         logger.info("skipping saving the user avatar")
                         return True
 
diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
index 72b0f1c5de..1957426c6a 100644
--- a/synapse/media/media_repository.py
+++ b/synapse/media/media_repository.py
@@ -19,6 +19,7 @@ import shutil
 from io import BytesIO
 from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
 
+import attr
 from matrix_common.types.mxc_uri import MXCUri
 
 import twisted.internet.error
@@ -50,6 +51,7 @@ from synapse.media.storage_provider import StorageProviderWrapper
 from synapse.media.thumbnailer import Thumbnailer, ThumbnailError
 from synapse.media.url_previewer import UrlPreviewer
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.databases.main.media_repository import RemoteMedia
 from synapse.types import UserID
 from synapse.util.async_helpers import Linearizer
 from synapse.util.retryutils import NotRetryingDestination
@@ -245,18 +247,18 @@ class MediaRepository:
             Resolves once a response has successfully been written to request
         """
         media_info = await self.store.get_local_media(media_id)
-        if not media_info or media_info["quarantined_by"]:
+        if not media_info or media_info.quarantined_by:
             respond_404(request)
             return
 
         self.mark_recently_accessed(None, media_id)
 
-        media_type = media_info["media_type"]
+        media_type = media_info.media_type
         if not media_type:
             media_type = "application/octet-stream"
-        media_length = media_info["media_length"]
-        upload_name = name if name else media_info["upload_name"]
-        url_cache = media_info["url_cache"]
+        media_length = media_info.media_length
+        upload_name = name if name else media_info.upload_name
+        url_cache = media_info.url_cache
 
         file_info = FileInfo(None, media_id, url_cache=bool(url_cache))
 
@@ -310,16 +312,20 @@ class MediaRepository:
 
         # We deliberately stream the file outside the lock
         if responder:
-            media_type = media_info["media_type"]
-            media_length = media_info["media_length"]
-            upload_name = name if name else media_info["upload_name"]
+            upload_name = name if name else media_info.upload_name
             await respond_with_responder(
-                request, responder, media_type, media_length, upload_name
+                request,
+                responder,
+                media_info.media_type,
+                media_info.media_length,
+                upload_name,
             )
         else:
             respond_404(request)
 
-    async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
+    async def get_remote_media_info(
+        self, server_name: str, media_id: str
+    ) -> RemoteMedia:
         """Gets the media info associated with the remote file, downloading
         if necessary.
 
@@ -353,7 +359,7 @@ class MediaRepository:
 
     async def _get_remote_media_impl(
         self, server_name: str, media_id: str
-    ) -> Tuple[Optional[Responder], dict]:
+    ) -> Tuple[Optional[Responder], RemoteMedia]:
         """Looks for media in local cache, if not there then attempt to
         download from remote server.
 
@@ -373,15 +379,17 @@ class MediaRepository:
 
         # If we have an entry in the DB, try and look for it
         if media_info:
-            file_id = media_info["filesystem_id"]
+            file_id = media_info.filesystem_id
             file_info = FileInfo(server_name, file_id)
 
-            if media_info["quarantined_by"]:
+            if media_info.quarantined_by:
                 logger.info("Media is quarantined")
                 raise NotFoundError()
 
-            if not media_info["media_type"]:
-                media_info["media_type"] = "application/octet-stream"
+            if not media_info.media_type:
+                media_info = attr.evolve(
+                    media_info, media_type="application/octet-stream"
+                )
 
             responder = await self.media_storage.fetch_media(file_info)
             if responder:
@@ -403,9 +411,9 @@ class MediaRepository:
             if not media_info:
                 raise e
 
-        file_id = media_info["filesystem_id"]
-        if not media_info["media_type"]:
-            media_info["media_type"] = "application/octet-stream"
+        file_id = media_info.filesystem_id
+        if not media_info.media_type:
+            media_info = attr.evolve(media_info, media_type="application/octet-stream")
         file_info = FileInfo(server_name, file_id)
 
         # We generate thumbnails even if another process downloaded the media
@@ -415,7 +423,7 @@ class MediaRepository:
         # otherwise they'll request thumbnails and get a 404 if they're not
         # ready yet.
         await self._generate_thumbnails(
-            server_name, media_id, file_id, media_info["media_type"]
+            server_name, media_id, file_id, media_info.media_type
         )
 
         responder = await self.media_storage.fetch_media(file_info)
@@ -425,7 +433,7 @@ class MediaRepository:
         self,
         server_name: str,
         media_id: str,
-    ) -> dict:
+    ) -> RemoteMedia:
         """Attempt to download the remote file from the given server name,
         using the given file_id as the local id.
 
@@ -518,7 +526,7 @@ class MediaRepository:
                 origin=server_name,
                 media_id=media_id,
                 media_type=media_type,
-                time_now_ms=self.clock.time_msec(),
+                time_now_ms=time_now_ms,
                 upload_name=upload_name,
                 media_length=length,
                 filesystem_id=file_id,
@@ -526,15 +534,17 @@ class MediaRepository:
 
         logger.info("Stored remote media in file %r", fname)
 
-        media_info = {
-            "media_type": media_type,
-            "media_length": length,
-            "upload_name": upload_name,
-            "created_ts": time_now_ms,
-            "filesystem_id": file_id,
-        }
-
-        return media_info
+        return RemoteMedia(
+            media_origin=server_name,
+            media_id=media_id,
+            media_type=media_type,
+            media_length=length,
+            upload_name=upload_name,
+            created_ts=time_now_ms,
+            filesystem_id=file_id,
+            last_access_ts=time_now_ms,
+            quarantined_by=None,
+        )
 
     def _get_thumbnail_requirements(
         self, media_type: str
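
The `attr.evolve` calls above exist because `RemoteMedia` is a frozen attrs class: unlike the dicts it replaces, it cannot be mutated in place to default the media type. A minimal sketch of the pattern (class and values illustrative):

    from typing import Optional

    import attr

    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class Media:
        media_id: str
        media_type: Optional[str]

    m = Media(media_id="abc123", media_type=None)
    # Assignment to a frozen instance raises FrozenInstanceError, so a
    # modified copy is built instead:
    m = attr.evolve(m, media_type="application/octet-stream")
    assert m.media_type == "application/octet-stream"
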
diff --git a/synapse/media/url_previewer.py b/synapse/media/url_previewer.py
index 9b5a3dd5f4..44aac21de6 100644
--- a/synapse/media/url_previewer.py
+++ b/synapse/media/url_previewer.py
@@ -240,15 +240,14 @@ class UrlPreviewer:
         cache_result = await self.store.get_url_cache(url, ts)
         if (
             cache_result
-            and cache_result["expires_ts"] > ts
-            and cache_result["response_code"] / 100 == 2
+            and cache_result.expires_ts > ts
+            and cache_result.response_code // 100 == 2
         ):
             # It may be stored as text in the database, not as bytes (such as
             # PostgreSQL). If so, encode it back before handing it on.
-            og = cache_result["og"]
-            if isinstance(og, str):
-                og = og.encode("utf8")
-            return og
+            if isinstance(cache_result.og, str):
+                return cache_result.og.encode("utf8")
+            return cache_result.og
 
         # If this URL can be accessed via an allowed oEmbed, use that instead.
         url_to_download = url
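
The change from `/` to `//` in the cache check above is a behaviour fix as well as a cleanup: with true division, `response_code / 100 == 2` holds only for exactly 200, so cached previews recorded with any other 2xx code were never treated as successful. Floor division matches the whole class:

    # True division only ever matches 200 exactly:
    assert (200 / 100 == 2) is True
    assert (204 / 100 == 2) is False

    # Floor division matches every 2xx status code:
    assert all(code // 100 == 2 for code in (200, 201, 204, 299))
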
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 755c59274c..812144a128 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -1860,7 +1860,8 @@ class PublicRoomListManager:
         if not room:
             return False
 
-        return room.get("is_public", False)
+        # The first item is whether the room is public.
+        return room[0]
 
     async def add_room_to_public_room_list(self, room_id: str) -> None:
         """Publishes a room to the public room list.
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 23a034522c..7e40bea8aa 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -413,8 +413,8 @@ class RoomMembersRestServlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
-        ret = await self.store.get_room(room_id)
-        if not ret:
+        room = await self.store.get_room(room_id)
+        if not room:
             raise NotFoundError("Room not found")
 
         members = await self.store.get_users_in_room(room_id)
@@ -442,8 +442,8 @@ class RoomStateRestServlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
-        ret = await self.store.get_room(room_id)
-        if not ret:
+        room = await self.store.get_room(room_id)
+        if not room:
             raise NotFoundError("Room not found")
 
         event_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py
index 82944ca711..3534c3c259 100644
--- a/synapse/rest/client/directory.py
+++ b/synapse/rest/client/directory.py
@@ -147,7 +147,7 @@ class ClientDirectoryListServer(RestServlet):
         if room is None:
             raise NotFoundError("Unknown room")
 
-        return 200, {"visibility": "public" if room["is_public"] else "private"}
+        return 200, {"visibility": "public" if room[0] else "private"}
 
     class PutBody(RequestBodyModel):
         visibility: Literal["public", "private"] = "public"
diff --git a/synapse/rest/media/thumbnail_resource.py b/synapse/rest/media/thumbnail_resource.py
index 85b6bdbe72..efda8b4ab4 100644
--- a/synapse/rest/media/thumbnail_resource.py
+++ b/synapse/rest/media/thumbnail_resource.py
@@ -119,7 +119,7 @@ class ThumbnailResource(RestServlet):
         if not media_info:
             respond_404(request)
             return
-        if media_info["quarantined_by"]:
+        if media_info.quarantined_by:
             logger.info("Media is quarantined")
             respond_404(request)
             return
@@ -134,7 +134,7 @@ class ThumbnailResource(RestServlet):
             thumbnail_infos,
             media_id,
             media_id,
-            url_cache=bool(media_info["url_cache"]),
+            url_cache=bool(media_info.url_cache),
             server_name=None,
         )
 
@@ -152,7 +152,7 @@ class ThumbnailResource(RestServlet):
         if not media_info:
             respond_404(request)
             return
-        if media_info["quarantined_by"]:
+        if media_info.quarantined_by:
             logger.info("Media is quarantined")
             respond_404(request)
             return
@@ -168,7 +168,7 @@ class ThumbnailResource(RestServlet):
                 file_info = FileInfo(
                     server_name=None,
                     file_id=media_id,
-                    url_cache=media_info["url_cache"],
+                    url_cache=bool(media_info.url_cache),
                     thumbnail=info,
                 )
 
@@ -188,7 +188,7 @@ class ThumbnailResource(RestServlet):
             desired_height,
             desired_method,
             desired_type,
-            url_cache=bool(media_info["url_cache"]),
+            url_cache=bool(media_info.url_cache),
         )
 
         if file_path:
@@ -213,7 +213,7 @@ class ThumbnailResource(RestServlet):
             server_name, media_id
         )
 
-        file_id = media_info["filesystem_id"]
+        file_id = media_info.filesystem_id
 
         for info in thumbnail_infos:
             t_w = info.width == desired_width
@@ -224,7 +224,7 @@ class ThumbnailResource(RestServlet):
             if t_w and t_h and t_method and t_type:
                 file_info = FileInfo(
                     server_name=server_name,
-                    file_id=media_info["filesystem_id"],
+                    file_id=file_id,
                     thumbnail=info,
                 )
 
@@ -280,7 +280,7 @@ class ThumbnailResource(RestServlet):
             m_type,
             thumbnail_infos,
             media_id,
-            media_info["filesystem_id"],
+            media_info.filesystem_id,
             url_cache=False,
             server_name=server_name,
         )
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 05775425b7..92b3c77b6d 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -1660,7 +1660,7 @@ class DatabasePool:
         retcols: Collection[str],
         allow_none: Literal[False] = False,
         desc: str = "simple_select_one",
-    ) -> Dict[str, Any]:
+    ) -> Tuple[Any, ...]:
         ...
 
     @overload
@@ -1671,7 +1671,7 @@ class DatabasePool:
         retcols: Collection[str],
         allow_none: Literal[True] = True,
         desc: str = "simple_select_one",
-    ) -> Optional[Dict[str, Any]]:
+    ) -> Optional[Tuple[Any, ...]]:
         ...
 
     async def simple_select_one(
@@ -1681,7 +1681,7 @@ class DatabasePool:
         retcols: Collection[str],
         allow_none: bool = False,
         desc: str = "simple_select_one",
-    ) -> Optional[Dict[str, Any]]:
+    ) -> Optional[Tuple[Any, ...]]:
         """Executes a SELECT query on the named table, which is expected to
         return a single row, returning multiple columns from it.
 
@@ -2190,7 +2190,7 @@ class DatabasePool:
         keyvalues: Dict[str, Any],
         retcols: Collection[str],
         allow_none: bool = False,
-    ) -> Optional[Dict[str, Any]]:
+    ) -> Optional[Tuple[Any, ...]]:
         select_sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
 
         if keyvalues:
@@ -2208,7 +2208,7 @@ class DatabasePool:
         if txn.rowcount > 1:
             raise StoreError(500, "More than one row matched (%s)" % (table,))
 
-        return dict(zip(retcols, row))
+        return row
 
     async def simple_delete_one(
         self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"
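
The paired `@overload` declarations let mypy pick the return type from the literal value of `allow_none`: only callers passing `allow_none=True` have to handle `None`. A stripped-down sketch of the same pattern:

    from typing import Any, Literal, Optional, Tuple, overload

    @overload
    def select_one(allow_none: Literal[False] = False) -> Tuple[Any, ...]:
        ...

    @overload
    def select_one(allow_none: Literal[True] = True) -> Optional[Tuple[Any, ...]]:
        ...

    def select_one(allow_none: bool = False) -> Optional[Tuple[Any, ...]]:
        row = None  # stand-in for an actual database fetch
        if row is None and not allow_none:
            raise RuntimeError("expected exactly one row")
        return row

    # Under mypy:
    #   select_one()                 -> Tuple[Any, ...]
    #   select_one(allow_none=True)  -> Optional[Tuple[Any, ...]]
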
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index d7482a1f4e..07f9b65af3 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -747,8 +747,16 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
         )
 
         # Invalidate the cache for any ignored users which were added or removed.
-        for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
-            self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
+        self._invalidate_cache_and_stream_bulk(
+            txn,
+            self.ignored_by,
+            [
+                (ignored_user_id,)
+                for ignored_user_id in (
+                    previously_ignored_users ^ currently_ignored_users
+                )
+            ],
+        )
         self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))
 
     async def remove_account_data_for_user(
@@ -824,10 +832,14 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
                 )
 
                 # Invalidate the cache for ignored users which were removed.
-                for ignored_user_id in previously_ignored_users:
-                    self._invalidate_cache_and_stream(
-                        txn, self.ignored_by, (ignored_user_id,)
-                    )
+                self._invalidate_cache_and_stream_bulk(
+                    txn,
+                    self.ignored_by,
+                    [
+                        (ignored_user_id,)
+                        for ignored_user_id in previously_ignored_users
+                    ],
+                )
 
                 # Invalidate for this user the cache tracking ignored users.
                 self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 4d0470ffd9..d7232f566b 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -483,6 +483,30 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         txn.call_after(cache_func.invalidate, keys)
         self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
 
+    def _invalidate_cache_and_stream_bulk(
+        self,
+        txn: LoggingTransaction,
+        cache_func: CachedFunction,
+        key_tuples: Collection[Tuple[Any, ...]],
+    ) -> None:
+        """A bulk version of _invalidate_cache_and_stream.
+
+        Locally invalidate every key-tuple in `key_tuples`, then emit invalidations
+        for each key-tuple over replication.
+
+        This implementation is more efficient than a loop which repeatedly calls the
+        non-bulk version.
+        """
+        if not key_tuples:
+            return
+
+        for keys in key_tuples:
+            txn.call_after(cache_func.invalidate, keys)
+
+        self._send_invalidation_to_replication_bulk(
+            txn, cache_func.__name__, key_tuples
+        )
+
     def _invalidate_all_cache_and_stream(
         self, txn: LoggingTransaction, cache_func: CachedFunction
     ) -> None:
@@ -564,10 +588,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         if isinstance(self.database_engine, PostgresEngine):
             assert self._cache_id_gen is not None
 
-            # get_next() returns a context manager which is designed to wrap
-            # the transaction. However, we want to only get an ID when we want
-            # to use it, here, so we need to call __enter__ manually, and have
-            # __exit__ called after the transaction finishes.
             stream_id = self._cache_id_gen.get_next_txn(txn)
             txn.call_after(self.hs.get_notifier().on_new_replication_data)
 
@@ -586,6 +606,53 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                 },
             )
 
+    def _send_invalidation_to_replication_bulk(
+        self,
+        txn: LoggingTransaction,
+        cache_name: str,
+        key_tuples: Collection[Tuple[Any, ...]],
+    ) -> None:
+        """Announce the invalidation of multiple (but not all) cache entries.
+
+        This is more efficient than repeated calls to the non-bulk version. It should
+    NOT be used to invalidate the entire cache: use
+        `_send_invalidation_to_replication` with keys=None.
+
+        Note that this does *not* invalidate the cache locally.
+
+        Args:
+            txn
+            cache_name
+            key_tuples: Key-tuples to invalidate. Assumed to be non-empty.
+        """
+        if isinstance(self.database_engine, PostgresEngine):
+            assert self._cache_id_gen is not None
+
+            stream_ids = self._cache_id_gen.get_next_mult_txn(txn, len(key_tuples))
+            ts = self._clock.time_msec()
+            txn.call_after(self.hs.get_notifier().on_new_replication_data)
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="cache_invalidation_stream_by_instance",
+                keys=(
+                    "stream_id",
+                    "instance_name",
+                    "cache_func",
+                    "keys",
+                    "invalidation_ts",
+                ),
+                values=[
+                    # We convert key_tuples to a list here because psycopg2 serialises
+                    # lists as SQL arrays, but serialises tuples as "composite types".
+                    # (We need an array because the `keys` column has type `text[]`.)
+                    # See:
+                    #     https://www.psycopg.org/docs/usage.html#adapt-list
+                    #     https://www.psycopg.org/docs/usage.html#adapt-tuple
+                    (stream_id, self._instance_name, cache_name, list(key_tuple), ts)
+                    for stream_id, key_tuple in zip(stream_ids, key_tuples)
+                ],
+            )
+
     def get_cache_stream_token_for_writer(self, instance_name: str) -> int:
         if self._cache_id_gen:
             return self._cache_id_gen.get_current_token_for_writer(instance_name)
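
The `list(key_tuple)` conversion in `_send_invalidation_to_replication_bulk` is load-bearing: psycopg2 adapts lists and tuples to different SQL constructs, and the `keys` column is a Postgres array. A quick illustration (assumes an open psycopg2 cursor `cur`; the statements are only rendered, never executed):

    # A Python list is adapted to a SQL array -- valid for a text[] column:
    print(cur.mogrify("SELECT %s", (["a", "b"],)))  # b"SELECT ARRAY['a','b']"

    # A Python tuple is adapted to a record ("composite type") instead:
    print(cur.mogrify("SELECT %s", (("a", "b"),)))  # b"SELECT ('a', 'b')"
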
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 04d12a876c..775abbac79 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -255,33 +255,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             A dict containing the device information, or `None` if the device does not
             exist.
         """
-        return await self.db_pool.simple_select_one(
-            table="devices",
-            keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
-            retcols=("user_id", "device_id", "display_name"),
-            desc="get_device",
-            allow_none=True,
-        )
-
-    async def get_device_opt(
-        self, user_id: str, device_id: str
-    ) -> Optional[Dict[str, Any]]:
-        """Retrieve a device. Only returns devices that are not marked as
-        hidden.
-
-        Args:
-            user_id: The ID of the user which owns the device
-            device_id: The ID of the device to retrieve
-        Returns:
-            A dict containing the device information, or None if the device does not exist.
-        """
-        return await self.db_pool.simple_select_one(
+        row = await self.db_pool.simple_select_one(
             table="devices",
             keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
             retcols=("user_id", "device_id", "display_name"),
             desc="get_device",
             allow_none=True,
         )
+        if row is None:
+            return None
+        return {"user_id": row[0], "device_id": row[1], "display_name": row[2]}
 
     async def get_devices_by_user(
         self, user_id: str
@@ -1221,9 +1204,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             retcols=["device_id", "device_data"],
             allow_none=True,
         )
-        return (
-            (row["device_id"], json_decoder.decode(row["device_data"])) if row else None
-        )
+        return (row[0], json_decoder.decode(row[1])) if row else None
 
     def _store_dehydrated_device_txn(
         self,
@@ -2326,13 +2307,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         `FALSE` have not been converted.
         """
 
-        row = await self.db_pool.simple_select_one(
-            table="device_lists_changes_converted_stream_position",
-            keyvalues={},
-            retcols=["stream_id", "room_id"],
-            desc="get_device_change_last_converted_pos",
+        return cast(
+            Tuple[int, str],
+            await self.db_pool.simple_select_one(
+                table="device_lists_changes_converted_stream_position",
+                keyvalues={},
+                retcols=["stream_id", "room_id"],
+                desc="get_device_change_last_converted_pos",
+            ),
         )
-        return row["stream_id"], row["room_id"]
 
     async def set_device_change_last_converted_pos(
         self,
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index ad904a26a6..fae23c3407 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -506,19 +506,26 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
                     # it isn't there.
                     raise StoreError(404, "No backup with that version exists")
 
-            result = self.db_pool.simple_select_one_txn(
-                txn,
-                table="e2e_room_keys_versions",
-                keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
-                retcols=("version", "algorithm", "auth_data", "etag"),
-                allow_none=False,
+            row = cast(
+                Tuple[int, str, str, Optional[int]],
+                self.db_pool.simple_select_one_txn(
+                    txn,
+                    table="e2e_room_keys_versions",
+                    keyvalues={
+                        "user_id": user_id,
+                        "version": this_version,
+                        "deleted": 0,
+                    },
+                    retcols=("version", "algorithm", "auth_data", "etag"),
+                    allow_none=False,
+                ),
             )
-            assert result is not None  # see comment on `simple_select_one_txn`
-            result["auth_data"] = db_to_json(result["auth_data"])
-            result["version"] = str(result["version"])
-            if result["etag"] is None:
-                result["etag"] = 0
-            return result
+            return {
+                "auth_data": db_to_json(row[2]),
+                "version": str(row[0]),
+                "algorithm": row[1],
+                "etag": 0 if row[3] is None else row[3],
+            }
 
         return await self.db_pool.runInteraction(
             "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 4f96ac25c7..8cb61eaee3 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -1237,13 +1237,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         for user_id, device_id, algorithm, key_id, key_json in claimed_keys:
             device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
             device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
-
-            if (user_id, device_id) in seen_user_device:
-                continue
             seen_user_device.add((user_id, device_id))
-            self._invalidate_cache_and_stream(
-                txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
-            )
+
+        self._invalidate_cache_and_stream_bulk(
+            txn, self.get_e2e_unused_fallback_key_types, seen_user_device
+        )
 
         return results
 
@@ -1268,9 +1266,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             if row is None:
                 continue
 
-            key_id = row["key_id"]
-            key_json = row["key_json"]
-            used = row["used"]
+            key_id, key_json, used = row
 
             # Mark fallback key as used if not already.
             if not used and mark_as_used:
@@ -1376,14 +1372,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             List[Tuple[str, str, str, str, str]], txn.execute_values(sql, query_list)
         )
 
-        seen_user_device: Set[Tuple[str, str]] = set()
-        for user_id, device_id, _, _, _ in otk_rows:
-            if (user_id, device_id) in seen_user_device:
-                continue
-            seen_user_device.add((user_id, device_id))
-            self._invalidate_cache_and_stream(
-                txn, self.count_e2e_one_time_keys, (user_id, device_id)
-            )
+        seen_user_device = {
+            (user_id, device_id) for user_id, device_id, _, _, _ in otk_rows
+        }
+        self._invalidate_cache_and_stream_bulk(
+            txn,
+            self.count_e2e_one_time_keys,
+            seen_user_device,
+        )
 
         return otk_rows
 
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index b8bbd1eccd..98556a0523 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -193,7 +193,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         # Check if we have indexed the room so we can use the chain cover
         # algorithm.
         room = await self.get_room(room_id)  # type: ignore[attr-defined]
-        if room["has_auth_chain_index"]:
+        # If the room has an auth chain index.
+        if room[1]:
             try:
                 return await self.db_pool.runInteraction(
                     "get_auth_chain_ids_chains",
@@ -410,7 +411,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         # Check if we have indexed the room so we can use the chain cover
         # algorithm.
         room = await self.get_room(room_id)  # type: ignore[attr-defined]
-        if room["has_auth_chain_index"]:
+        # If the room has an auth chain index.
+        if room[1]:
             try:
                 return await self.db_pool.runInteraction(
                     "get_auth_chain_difference_chains",
@@ -1436,24 +1438,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             )
 
             if event_lookup_result is not None:
+                event_type, depth, stream_ordering = event_lookup_result
                 logger.debug(
                     "_get_backfill_events(room_id=%s): seed_event_id=%s depth=%s stream_ordering=%s type=%s",
                     room_id,
                     seed_event_id,
-                    event_lookup_result["depth"],
-                    event_lookup_result["stream_ordering"],
-                    event_lookup_result["type"],
+                    depth,
+                    stream_ordering,
+                    event_type,
                 )
 
-                if event_lookup_result["depth"]:
-                    queue.put(
-                        (
-                            -event_lookup_result["depth"],
-                            -event_lookup_result["stream_ordering"],
-                            seed_event_id,
-                            event_lookup_result["type"],
-                        )
-                    )
+                if depth:
+                    queue.put((-depth, -stream_ordering, seed_event_id, event_type))
 
         while not queue.empty() and len(event_id_results) < limit:
             try:
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 7c34bde3e5..5207cc0f4e 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1934,8 +1934,7 @@ class PersistEventsStore:
         if row is None:
             return
 
-        redacted_relates_to = row["relates_to_id"]
-        rel_type = row["relation_type"]
+        redacted_relates_to, rel_type = row
         self.db_pool.simple_delete_txn(
             txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
         )
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 0061805150..9c46c5d7bd 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -1222,14 +1222,13 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
                 )
 
                 # Iterate the parent IDs and invalidate caches.
-                for parent_id in {r[1] for r in relations_to_insert}:
-                    cache_tuple = (parent_id,)
-                    self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
-                        txn, self.get_relations_for_event, cache_tuple  # type: ignore[attr-defined]
-                    )
-                    self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
-                        txn, self.get_thread_summary, cache_tuple  # type: ignore[attr-defined]
-                    )
+                cache_tuples = {(r[1],) for r in relations_to_insert}
+                self._invalidate_cache_and_stream_bulk(  # type: ignore[attr-defined]
+                    txn, self.get_relations_for_event, cache_tuples  # type: ignore[attr-defined]
+                )
+                self._invalidate_cache_and_stream_bulk(  # type: ignore[attr-defined]
+                    txn, self.get_thread_summary, cache_tuples  # type: ignore[attr-defined]
+                )
 
             if results:
                 latest_event_id = results[-1][0]
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 5bf864c1fb..4e63a16fa2 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1998,7 +1998,7 @@ class EventsWorkerStore(SQLBaseStore):
         if not res:
             raise SynapseError(404, "Could not find event %s" % (event_id,))
 
-        return int(res["topological_ordering"]), int(res["stream_ordering"])
+        return int(res[0]), int(res[1])
 
     async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]:
         """Retrieve the entry with the lowest expiry timestamp in the event_expiry
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index ce88772f9e..c700872fdc 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -107,13 +107,16 @@ class KeyStore(CacheInvalidationWorkerStore):
             # invalidate takes a tuple corresponding to the params of
             # _get_server_keys_json. _get_server_keys_json only takes one
             # param, which is itself the 2-tuple (server_name, key_id).
-            for key_id in verify_keys:
-                self._invalidate_cache_and_stream(
-                    txn, self._get_server_keys_json, ((server_name, key_id),)
-                )
-                self._invalidate_cache_and_stream(
-                    txn, self.get_server_key_json_for_remote, (server_name, key_id)
-                )
+            self._invalidate_cache_and_stream_bulk(
+                txn,
+                self._get_server_keys_json,
+                [((server_name, key_id),) for key_id in verify_keys],
+            )
+            self._invalidate_cache_and_stream_bulk(
+                txn,
+                self.get_server_key_json_for_remote,
+                [(server_name, key_id) for key_id in verify_keys],
+            )
 
         await self.db_pool.runInteraction(
             "store_server_keys_response", store_server_keys_response_txn
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index c8d7c9fd32..3f80a64dc5 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -15,9 +15,7 @@
 from enum import Enum
 from typing import (
     TYPE_CHECKING,
-    Any,
     Collection,
-    Dict,
     Iterable,
     List,
     Optional,
@@ -54,11 +52,32 @@ class LocalMedia:
     media_length: int
     upload_name: str
     created_ts: int
+    url_cache: Optional[str]
     last_access_ts: int
     quarantined_by: Optional[str]
     safe_from_quarantine: bool
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RemoteMedia:
+    media_origin: str
+    media_id: str
+    media_type: str
+    media_length: int
+    upload_name: Optional[str]
+    filesystem_id: str
+    created_ts: int
+    last_access_ts: int
+    quarantined_by: Optional[str]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class UrlCache:
+    response_code: int
+    expires_ts: int
+    og: Union[str, bytes]
+
+
 class MediaSortOrder(Enum):
     """
     Enum to define the sorting method used when returning media with
@@ -165,13 +184,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         super().__init__(database, db_conn, hs)
         self.server_name: str = hs.hostname
 
-    async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
+    async def get_local_media(self, media_id: str) -> Optional[LocalMedia]:
         """Get the metadata for a local piece of media
 
         Returns:
             None if the media_id doesn't exist.
         """
-        return await self.db_pool.simple_select_one(
+        row = await self.db_pool.simple_select_one(
             "local_media_repository",
             {"media_id": media_id},
             (
@@ -181,11 +200,25 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 "created_ts",
                 "quarantined_by",
                 "url_cache",
+                "last_access_ts",
                 "safe_from_quarantine",
             ),
             allow_none=True,
             desc="get_local_media",
         )
+        if row is None:
+            return None
+        return LocalMedia(
+            media_id=media_id,
+            media_type=row[0],
+            media_length=row[1],
+            upload_name=row[2],
+            created_ts=row[3],
+            quarantined_by=row[4],
+            url_cache=row[5],
+            last_access_ts=row[6],
+            safe_from_quarantine=row[7],
+        )
 
     async def get_local_media_by_user_paginate(
         self,
@@ -236,6 +269,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                     media_length,
                     upload_name,
                     created_ts,
+                    url_cache,
                     last_access_ts,
                     quarantined_by,
                     safe_from_quarantine
@@ -257,9 +291,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                     media_length=row[2],
                     upload_name=row[3],
                     created_ts=row[4],
-                    last_access_ts=row[5],
-                    quarantined_by=row[6],
-                    safe_from_quarantine=bool(row[7]),
+                    url_cache=row[5],
+                    last_access_ts=row[6],
+                    quarantined_by=row[7],
+                    safe_from_quarantine=bool(row[8]),
                 )
                 for row in txn
             ]
@@ -390,51 +425,39 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="mark_local_media_as_safe",
         )
 
-    async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]:
+    async def get_url_cache(self, url: str, ts: int) -> Optional[UrlCache]:
         """Get the media_id and ts for a cached URL as of the given timestamp
         Returns:
             None if the URL isn't cached.
         """
 
-        def get_url_cache_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
+        def get_url_cache_txn(txn: LoggingTransaction) -> Optional[UrlCache]:
             # get the most recently cached result (relative to the given ts)
-            sql = (
-                "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
-                " FROM local_media_repository_url_cache"
-                " WHERE url = ? AND download_ts <= ?"
-                " ORDER BY download_ts DESC LIMIT 1"
-            )
+            sql = """
+                SELECT response_code, expires_ts, og
+                FROM local_media_repository_url_cache
+                WHERE url = ? AND download_ts <= ?
+                ORDER BY download_ts DESC LIMIT 1
+            """
             txn.execute(sql, (url, ts))
             row = txn.fetchone()
 
             if not row:
                 # ...or if we've requested a timestamp older than the oldest
                 # copy in the cache, return the oldest copy (if any)
-                sql = (
-                    "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
-                    " FROM local_media_repository_url_cache"
-                    " WHERE url = ? AND download_ts > ?"
-                    " ORDER BY download_ts ASC LIMIT 1"
-                )
+                sql = """
+                    SELECT response_code, expires_ts, og
+                    FROM local_media_repository_url_cache
+                    WHERE url = ? AND download_ts > ?
+                    ORDER BY download_ts ASC LIMIT 1
+                """
                 txn.execute(sql, (url, ts))
                 row = txn.fetchone()
 
             if not row:
                 return None
 
-            return dict(
-                zip(
-                    (
-                        "response_code",
-                        "etag",
-                        "expires_ts",
-                        "og",
-                        "media_id",
-                        "download_ts",
-                    ),
-                    row,
-                )
-            )
+            return UrlCache(response_code=row[0], expires_ts=row[1], og=row[2])
 
         return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
 
@@ -444,7 +467,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         response_code: int,
         etag: Optional[str],
         expires_ts: int,
-        og: Optional[str],
+        og: str,
         media_id: str,
         download_ts: int,
     ) -> None:
@@ -510,8 +533,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
     async def get_cached_remote_media(
         self, origin: str, media_id: str
-    ) -> Optional[Dict[str, Any]]:
-        return await self.db_pool.simple_select_one(
+    ) -> Optional[RemoteMedia]:
+        row = await self.db_pool.simple_select_one(
             "remote_media_cache",
             {"media_origin": origin, "media_id": media_id},
             (
@@ -520,11 +543,25 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 "upload_name",
                 "created_ts",
                 "filesystem_id",
+                "last_access_ts",
                 "quarantined_by",
             ),
             allow_none=True,
             desc="get_cached_remote_media",
         )
+        if row is None:
+            return None
+        return RemoteMedia(
+            media_origin=origin,
+            media_id=media_id,
+            media_type=row[0],
+            media_length=row[1],
+            upload_name=row[2],
+            created_ts=row[3],
+            filesystem_id=row[4],
+            last_access_ts=row[5],
+            quarantined_by=row[6],
+        )
 
     async def store_cached_remote_media(
         self,
@@ -623,10 +660,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         t_width: int,
         t_height: int,
         t_type: str,
-    ) -> Optional[Dict[str, Any]]:
+    ) -> Optional[ThumbnailInfo]:
         """Fetch the thumbnail info of given width, height and type."""
 
-        return await self.db_pool.simple_select_one(
+        row = await self.db_pool.simple_select_one(
             table="remote_media_cache_thumbnails",
             keyvalues={
                 "media_origin": origin,
@@ -641,11 +678,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 "thumbnail_method",
                 "thumbnail_type",
                 "thumbnail_length",
-                "filesystem_id",
             ),
             allow_none=True,
             desc="get_remote_media_thumbnail",
         )
+        if row is None:
+            return None
+        return ThumbnailInfo(
+            width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
+        )
 
     @trace
     async def store_remote_media_thumbnail(
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 3b444d2d07..0198bb09d2 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -363,10 +363,11 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
                 # for their user ID.
                 value_values=[(presence_stream_id,) for _ in user_ids],
             )
-            for user_id in user_ids:
-                self._invalidate_cache_and_stream(
-                    txn, self._get_full_presence_stream_token_for_user, (user_id,)
-                )
+            self._invalidate_cache_and_stream_bulk(
+                txn,
+                self._get_full_presence_stream_token_for_user,
+                [(user_id,) for user_id in user_ids],
+            )
 
         return await self.db_pool.runInteraction(
             "add_users_to_send_full_presence_to", _add_users_to_send_full_presence_to
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 3ba9cc8853..7ed111f632 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 from typing import TYPE_CHECKING, Optional
 
-from synapse.api.errors import StoreError
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
     DatabasePool,
@@ -138,23 +137,18 @@ class ProfileWorkerStore(SQLBaseStore):
         return 50
 
     async def get_profileinfo(self, user_id: UserID) -> ProfileInfo:
-        try:
-            profile = await self.db_pool.simple_select_one(
-                table="profiles",
-                keyvalues={"full_user_id": user_id.to_string()},
-                retcols=("displayname", "avatar_url"),
-                desc="get_profileinfo",
-            )
-        except StoreError as e:
-            if e.code == 404:
-                # no match
-                return ProfileInfo(None, None)
-            else:
-                raise
-
-        return ProfileInfo(
-            avatar_url=profile["avatar_url"], display_name=profile["displayname"]
+        profile = await self.db_pool.simple_select_one(
+            table="profiles",
+            keyvalues={"full_user_id": user_id.to_string()},
+            retcols=("displayname", "avatar_url"),
+            desc="get_profileinfo",
+            allow_none=True,
         )
+        if profile is None:
+            # no match
+            return ProfileInfo(None, None)
+
+        return ProfileInfo(avatar_url=profile[1], display_name=profile[0])
 
     async def get_profile_displayname(self, user_id: UserID) -> Optional[str]:
         return await self.db_pool.simple_select_one_onecol(
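
`get_profileinfo` also drops its exception-driven control flow: rather than catching a 404 `StoreError`, it passes `allow_none=True` and branches on `None`. Reduced to a standalone sketch, with `fetch_profile_row` standing in for `simple_select_one`:

    from typing import Optional, Tuple

    def fetch_profile_row(user_id: str) -> Optional[Tuple[Optional[str], Optional[str]]]:
        # Stand-in for simple_select_one(..., allow_none=True): the
        # (displayname, avatar_url) row, or None when there is no match.
        return None

    def get_profileinfo(user_id: str) -> Tuple[Optional[str], Optional[str]]:
        profile = fetch_profile_row(user_id)
        if profile is None:
            # No match: an empty profile, with no exception on the miss path.
            return (None, None)
        # retcols order is (displayname, avatar_url).
        return (profile[1], profile[0])  # (avatar_url, display_name)
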
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 1e11bf2706..c3b3e2baaf 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -295,19 +295,28 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         # so make sure to keep this actually last.
         txn.execute("DROP TABLE events_to_purge")
 
-        for event_id, should_delete in event_rows:
-            self._invalidate_cache_and_stream(
-                txn, self._get_state_group_for_event, (event_id,)
-            )
+        self._invalidate_cache_and_stream_bulk(
+            txn,
+            self._get_state_group_for_event,
+            [(event_id,) for event_id, _ in event_rows],
+        )
 
-            # XXX: This is racy, since have_seen_events could be called between the
-            #    transaction completing and the invalidation running. On the other hand,
-            #    that's no different to calling `have_seen_events` just before the
-            #    event is deleted from the database.
+        # XXX: This is racy, since have_seen_events could be called between the
+        #    transaction completing and the invalidation running. On the other hand,
+        #    that's no different to calling `have_seen_events` just before the
+        #    event is deleted from the database.
+        self._invalidate_cache_and_stream_bulk(
+            txn,
+            self.have_seen_event,
+            [
+                (room_id, event_id)
+                for event_id, should_delete in event_rows
+                if should_delete
+            ],
+        )
+
+        for event_id, should_delete in event_rows:
             if should_delete:
-                self._invalidate_cache_and_stream(
-                    txn, self.have_seen_event, (room_id, event_id)
-                )
                 self.invalidate_get_event_cache_after_txn(txn, event_id)
 
         logger.info("[purge] done")
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 37135d431d..f72a23c584 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -468,8 +468,7 @@ class PushRuleStore(PushRulesWorkerStore):
                 "before/after rule not found: %s" % (relative_to_rule,)
             )
 
-        base_priority_class = res["priority_class"]
-        base_rule_priority = res["priority"]
+        base_priority_class, base_rule_priority = res
 
         if base_priority_class != priority_class:
             raise InconsistentRuleException(
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 56e8eb16a8..3484ce9ef9 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -701,8 +701,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
             allow_none=True,
         )
 
-        stream_ordering = int(res["stream_ordering"]) if res else None
-        rx_ts = res["received_ts"] if res else 0
+        stream_ordering = int(res[0]) if res else None
+        rx_ts = res[1] if res else 0
 
         # We don't want to clobber receipts for more recent events, so we
         # have to compare orderings of existing receipts
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 933d76e905..2c3f30e2eb 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -425,17 +425,14 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                     account timestamp as milliseconds since the epoch. None if the account
                     has not been renewed using the current token yet.
         """
-        ret_dict = await self.db_pool.simple_select_one(
-            table="account_validity",
-            keyvalues={"renewal_token": renewal_token},
-            retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"],
-            desc="get_user_from_renewal_token",
-        )
-
-        return (
-            ret_dict["user_id"],
-            ret_dict["expiration_ts_ms"],
-            ret_dict["token_used_ts_ms"],
+        return cast(
+            Tuple[str, int, Optional[int]],
+            await self.db_pool.simple_select_one(
+                table="account_validity",
+                keyvalues={"renewal_token": renewal_token},
+                retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"],
+                desc="get_user_from_renewal_token",
+            ),
         )
 
     async def get_renewal_token_for_user(self, user_id: str) -> str:
@@ -564,16 +561,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 updatevalues={"shadow_banned": shadow_banned},
             )
             # In order for this to apply immediately, clear the cache for this user.
-            tokens = self.db_pool.simple_select_onecol_txn(
+            tokens = self.db_pool.simple_select_list_txn(
                 txn,
                 table="access_tokens",
                 keyvalues={"user_id": user_id},
-                retcol="token",
+                retcols=("token",),
+            )
+            self._invalidate_cache_and_stream_bulk(
+                txn, self.get_user_by_access_token, tokens
             )
-            for token in tokens:
-                self._invalidate_cache_and_stream(
-                    txn, self.get_user_by_access_token, (token,)
-                )
             self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
 
         await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn)
@@ -989,16 +985,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         Returns:
             user id, or None if no user id/threepid mapping exists
         """
-        ret = self.db_pool.simple_select_one_txn(
+        return self.db_pool.simple_select_one_onecol_txn(
             txn,
             "user_threepids",
             {"medium": medium, "address": address},
-            ["user_id"],
+            "user_id",
             True,
         )
-        if ret:
-            return ret["user_id"]
-        return None
 
     async def user_add_threepid(
         self,
@@ -1435,16 +1428,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         if res is None:
             return False
 
+        uses_allowed, pending, completed, expiry_time = res
+
         # Check if the token has expired
         now = self._clock.time_msec()
-        if res["expiry_time"] and res["expiry_time"] < now:
+        if expiry_time and expiry_time < now:
             return False
 
         # Check if the token has been used up
-        if (
-            res["uses_allowed"]
-            and res["pending"] + res["completed"] >= res["uses_allowed"]
-        ):
+        if uses_allowed and pending + completed >= uses_allowed:
             return False
 
         # Otherwise, the token is valid
@@ -1490,8 +1482,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             # Override type because the return type is only optional if
             # allow_none is True, and we don't want mypy throwing errors
             # about None not being indexable.
-            res = cast(
-                Dict[str, Any],
+            pending, completed = cast(
+                Tuple[int, int],
                 self.db_pool.simple_select_one_txn(
                     txn,
                     "registration_tokens",
@@ -1506,8 +1498,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 "registration_tokens",
                 keyvalues={"token": token},
                 updatevalues={
-                    "completed": res["completed"] + 1,
-                    "pending": res["pending"] - 1,
+                    "completed": completed + 1,
+                    "pending": pending - 1,
                 },
             )
 
@@ -1585,13 +1577,22 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         Returns:
             A dict, or None if token doesn't exist.
         """
-        return await self.db_pool.simple_select_one(
+        row = await self.db_pool.simple_select_one(
             "registration_tokens",
             keyvalues={"token": token},
             retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"],
             allow_none=True,
             desc="get_one_registration_token",
         )
+        if row is None:
+            return None
+        return {
+            "token": row[0],
+            "uses_allowed": row[1],
+            "pending": row[2],
+            "completed": row[3],
+            "expiry_time": row[4],
+        }
 
     async def generate_registration_token(
         self, length: int, chars: str
@@ -1714,7 +1715,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 return None
 
             # Get all info about the token so it can be sent in the response
-            return self.db_pool.simple_select_one_txn(
+            result = self.db_pool.simple_select_one_txn(
                 txn,
                 "registration_tokens",
                 keyvalues={"token": token},
@@ -1728,6 +1729,17 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 allow_none=True,
             )
 
+            if result is None:
+                return None
+
+            return {
+                "token": result[0],
+                "uses_allowed": result[1],
+                "pending": result[2],
+                "completed": result[3],
+                "expiry_time": result[4],
+            }
+
         return await self.db_pool.runInteraction(
             "update_registration_token", _update_registration_token_txn
         )
@@ -1939,11 +1951,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             keyvalues={"token": token},
             updatevalues={"used_ts": ts},
         )
-        user_id = values["user_id"]
-        expiry_ts = values["expiry_ts"]
-        used_ts = values["used_ts"]
-        auth_provider_id = values["auth_provider_id"]
-        auth_provider_session_id = values["auth_provider_session_id"]
+        (
+            user_id,
+            expiry_ts,
+            used_ts,
+            auth_provider_id,
+            auth_provider_session_id,
+        ) = values
 
         # Token was already used
         if used_ts is not None:
@@ -2668,10 +2682,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
             )
             tokens_and_devices = [(r[0], r[1], r[2]) for r in txn]
 
-            for token, _, _ in tokens_and_devices:
-                self._invalidate_cache_and_stream(
-                    txn, self.get_user_by_access_token, (token,)
-                )
+            self._invalidate_cache_and_stream_bulk(
+                txn,
+                self.get_user_by_access_token,
+                [(token,) for token, _, _ in tokens_and_devices],
+            )
 
             txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values)
 
@@ -2756,12 +2771,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
                     # reason, the next check is on the client secret, which is NOT NULL,
                     # so we don't have to worry about the client secret matching by
                     # accident.
-                    row = {"client_secret": None, "validated_at": None}
+                    row = (None, None)
                 else:
                     raise ThreepidValidationError("Unknown session_id")
 
-            retrieved_client_secret = row["client_secret"]
-            validated_at = row["validated_at"]
+            retrieved_client_secret, validated_at = row
 
             row = self.db_pool.simple_select_one_txn(
                 txn,
@@ -2775,8 +2789,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
                 raise ThreepidValidationError(
                     "Validation token not found or has expired"
                 )
-            expires = row["expires"]
-            next_link = row["next_link"]
+            expires, next_link = row
 
             if retrieved_client_secret != client_secret:
                 raise ThreepidValidationError(
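
Several hunks in this file wrap `simple_select_one` in `typing.cast`. The cast exists purely for the type checker: it asserts the tuple's shape and returns its argument unchanged at runtime. In miniature:

    from typing import Optional, Tuple, cast

    def select_one() -> Optional[Tuple[object, ...]]:
        # Stand-in for simple_select_one's loosely typed return value.
        return ("@user:test", 1_700_000_000_000, None)

    user_id, expiration_ts_ms, token_used_ts_ms = cast(
        Tuple[str, int, Optional[int]], select_one()
    )
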
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index afb880532e..ef26d5d9d3 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -213,21 +213,31 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             logger.error("store_room with room_id=%s failed: %s", room_id, e)
             raise StoreError(500, "Problem creating room.")
 
-    async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]:
+    async def get_room(self, room_id: str) -> Optional[Tuple[bool, bool]]:
         """Retrieve a room.
 
         Args:
             room_id: The ID of the room to retrieve.
         Returns:
-            A dict containing the room information, or None if the room is unknown.
+            A tuple containing the room information:
+                * True if the room is public
+                * True if the room has an auth chain index
+
+            or None if the room is unknown.
         """
-        return await self.db_pool.simple_select_one(
-            table="rooms",
-            keyvalues={"room_id": room_id},
-            retcols=("room_id", "is_public", "creator", "has_auth_chain_index"),
-            desc="get_room",
-            allow_none=True,
+        row = cast(
+            Optional[Tuple[Optional[Union[int, bool]], Optional[Union[int, bool]]]],
+            await self.db_pool.simple_select_one(
+                table="rooms",
+                keyvalues={"room_id": room_id},
+                retcols=("is_public", "has_auth_chain_index"),
+                desc="get_room",
+                allow_none=True,
+            ),
         )
+        if row is None:
+            return None
+        return bool(row[0]), bool(row[1])
 
     async def get_room_with_stats(self, room_id: str) -> Optional[RoomStats]:
         """Retrieve room with statistics.
@@ -794,10 +804,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         )
 
         if row:
-            return RatelimitOverride(
-                messages_per_second=row["messages_per_second"],
-                burst_count=row["burst_count"],
-            )
+            return RatelimitOverride(messages_per_second=row[0], burst_count=row[1])
         else:
             return None
 
@@ -1371,13 +1378,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         join.
         """
 
-        result = await self.db_pool.simple_select_one(
-            table="partial_state_rooms",
-            keyvalues={"room_id": room_id},
-            retcols=("join_event_id", "device_lists_stream_id"),
-            desc="get_join_event_id_for_partial_state",
+        return cast(
+            Tuple[str, int],
+            await self.db_pool.simple_select_one(
+                table="partial_state_rooms",
+                keyvalues={"room_id": room_id},
+                retcols=("join_event_id", "device_lists_stream_id"),
+                desc="get_join_event_id_for_partial_state",
+            ),
         )
-        return result["join_event_id"], result["device_lists_stream_id"]
 
     def get_un_partial_stated_rooms_token(self, instance_name: str) -> int:
         return self._un_partial_stated_rooms_stream_id_gen.get_current_token_for_writer(
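
The `Union[int, bool]` in `get_room`'s cast reflects a driver difference: Postgres returns real booleans for boolean columns, while SQLite stores them as integers, so the values are normalised with `bool()` before being returned. In isolation:

    row = (1, 0)  # what SQLite would return for (is_public, has_auth_chain_index)
    is_public, has_auth_chain_index = bool(row[0]), bool(row[1])
    assert (is_public, has_auth_chain_index) == (True, False)
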
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 1ed7f2d0ef..60d4a9ef30 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -559,17 +559,20 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
                 "non-local user %s" % (user_id,),
             )
 
-        results_dict = await self.db_pool.simple_select_one(
-            "local_current_membership",
-            {"room_id": room_id, "user_id": user_id},
-            ("membership", "event_id"),
-            allow_none=True,
-            desc="get_local_current_membership_for_user_in_room",
+        results = cast(
+            Optional[Tuple[str, str]],
+            await self.db_pool.simple_select_one(
+                "local_current_membership",
+                {"room_id": room_id, "user_id": user_id},
+                ("membership", "event_id"),
+                allow_none=True,
+                desc="get_local_current_membership_for_user_in_room",
+            ),
         )
-        if not results_dict:
+        if not results:
             return None, None
 
-        return results_dict.get("membership"), results_dict.get("event_id")
+        return results
 
     @cached(max_entries=500000, iterable=True)
     async def get_rooms_for_user_with_stream_ordering(
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 2225f8272d..563c275a2c 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -1014,9 +1014,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             desc="get_position_for_event",
         )
 
-        return PersistedEventPosition(
-            row["instance_name"] or "master", row["stream_ordering"]
-        )
+        return PersistedEventPosition(row[1] or "master", row[0])
 
     async def get_topological_token_for_event(self, event_id: str) -> RoomStreamToken:
         """The stream token for an event
@@ -1033,9 +1031,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             retcols=("stream_ordering", "topological_ordering"),
             desc="get_topological_token_for_event",
         )
-        return RoomStreamToken(
-            topological=row["topological_ordering"], stream=row["stream_ordering"]
-        )
+        return RoomStreamToken(topological=row[1], stream=row[0])
 
     async def get_current_topological_token(self, room_id: str, stream_key: int) -> int:
         """Gets the topological token in a room after or at the given stream
@@ -1180,26 +1176,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             dict
         """
 
-        results = self.db_pool.simple_select_one_txn(
-            txn,
-            "events",
-            keyvalues={"event_id": event_id, "room_id": room_id},
-            retcols=["stream_ordering", "topological_ordering"],
+        stream_ordering, topological_ordering = cast(
+            Tuple[int, int],
+            self.db_pool.simple_select_one_txn(
+                txn,
+                "events",
+                keyvalues={"event_id": event_id, "room_id": room_id},
+                retcols=["stream_ordering", "topological_ordering"],
+            ),
         )
 
-        # This cannot happen as `allow_none=False`.
-        assert results is not None
-
         # Paginating backwards includes the event at the token, but paginating
         # forward doesn't.
         before_token = RoomStreamToken(
-            topological=results["topological_ordering"] - 1,
-            stream=results["stream_ordering"],
+            topological=topological_ordering - 1, stream=stream_ordering
         )
 
         after_token = RoomStreamToken(
-            topological=results["topological_ordering"],
-            stream=results["stream_ordering"],
+            topological=topological_ordering, stream=stream_ordering
         )
 
         rows, start_token = self._paginate_room_events_txn(
diff --git a/synapse/storage/databases/main/task_scheduler.py b/synapse/storage/databases/main/task_scheduler.py
index 5555b53575..64543b4d61 100644
--- a/synapse/storage/databases/main/task_scheduler.py
+++ b/synapse/storage/databases/main/task_scheduler.py
@@ -183,39 +183,27 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
 
         Returns: the task if available, `None` otherwise
         """
-        row = await self.db_pool.simple_select_one(
-            table="scheduled_tasks",
-            keyvalues={"id": id},
-            retcols=(
-                "id",
-                "action",
-                "status",
-                "timestamp",
-                "resource_id",
-                "params",
-                "result",
-                "error",
+        row = cast(
+            Optional[ScheduledTaskRow],
+            await self.db_pool.simple_select_one(
+                table="scheduled_tasks",
+                keyvalues={"id": id},
+                retcols=(
+                    "id",
+                    "action",
+                    "status",
+                    "timestamp",
+                    "resource_id",
+                    "params",
+                    "result",
+                    "error",
+                ),
+                allow_none=True,
+                desc="get_scheduled_task",
             ),
-            allow_none=True,
-            desc="get_scheduled_task",
         )
 
-        return (
-            TaskSchedulerWorkerStore._convert_row_to_task(
-                (
-                    row["id"],
-                    row["action"],
-                    row["status"],
-                    row["timestamp"],
-                    row["resource_id"],
-                    row["params"],
-                    row["result"],
-                    row["error"],
-                )
-            )
-            if row
-            else None
-        )
+        return TaskSchedulerWorkerStore._convert_row_to_task(row) if row else None
 
     async def delete_scheduled_task(self, id: str) -> None:
         """Delete a specific task from its id.
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index fecddb4144..2d341affaa 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -118,19 +118,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
             txn,
             table="received_transactions",
             keyvalues={"transaction_id": transaction_id, "origin": origin},
-            retcols=(
-                "transaction_id",
-                "origin",
-                "ts",
-                "response_code",
-                "response_json",
-                "has_been_referenced",
-            ),
+            retcols=("response_code", "response_json"),
             allow_none=True,
         )
 
-        if result and result["response_code"]:
-            return result["response_code"], db_to_json(result["response_json"])
+        # If the result exists and the response code is non-zero.
+        if result and result[0]:
+            return result[0], db_to_json(result[1])
 
         else:
             return None
@@ -200,8 +194,10 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
 
         # check we have a row and retry_last_ts is not null or zero
         # (retry_last_ts can't be negative)
-        if result and result["retry_last_ts"]:
-            return DestinationRetryTimings(**result)
+        if result and result[1]:
+            return DestinationRetryTimings(
+                failure_ts=result[0], retry_last_ts=result[1], retry_interval=result[2]
+            )
         else:
             return None
 
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 8ab7c42c4a..5b164fed8e 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -122,9 +122,13 @@ class UIAuthWorkerStore(SQLBaseStore):
             desc="get_ui_auth_session",
         )
 
-        result["clientdict"] = db_to_json(result["clientdict"])
-
-        return UIAuthSessionData(session_id, **result)
+        return UIAuthSessionData(
+            session_id,
+            clientdict=db_to_json(result[0]),
+            uri=result[1],
+            method=result[2],
+            description=result[3],
+        )
 
     async def mark_ui_auth_stage_complete(
         self,
@@ -231,18 +235,15 @@ class UIAuthWorkerStore(SQLBaseStore):
         self, txn: LoggingTransaction, session_id: str, key: str, value: Any
     ) -> None:
         # Get the current value.
-        result = cast(
-            Dict[str, Any],
-            self.db_pool.simple_select_one_txn(
-                txn,
-                table="ui_auth_sessions",
-                keyvalues={"session_id": session_id},
-                retcols=("serverdict",),
-            ),
+        result = self.db_pool.simple_select_one_onecol_txn(
+            txn,
+            table="ui_auth_sessions",
+            keyvalues={"session_id": session_id},
+            retcol="serverdict",
         )
 
         # Update it and add it back to the database.
-        serverdict = db_to_json(result["serverdict"])
+        serverdict = db_to_json(result)
         serverdict[key] = value
 
         self.db_pool.simple_update_one_txn(
@@ -265,14 +266,14 @@ class UIAuthWorkerStore(SQLBaseStore):
         Raises:
             StoreError if the session cannot be found.
         """
-        result = await self.db_pool.simple_select_one(
+        result = await self.db_pool.simple_select_one_onecol(
             table="ui_auth_sessions",
             keyvalues={"session_id": session_id},
-            retcols=("serverdict",),
+            retcol="serverdict",
             desc="get_ui_auth_session_data",
         )
 
-        serverdict = db_to_json(result["serverdict"])
+        serverdict = db_to_json(result)
 
         return serverdict.get(key, default)
 
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index d4b86ed7a6..93ec06904a 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -20,7 +20,6 @@ from typing import (
     Collection,
     Iterable,
     List,
-    Mapping,
     Optional,
     Sequence,
     Set,
@@ -868,13 +867,25 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             "delete_all_from_user_dir", _delete_all_from_user_dir_txn
         )
 
-    async def _get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]:
-        return await self.db_pool.simple_select_one(
-            table="user_directory",
-            keyvalues={"user_id": user_id},
-            retcols=("display_name", "avatar_url"),
-            allow_none=True,
-            desc="get_user_in_directory",
+    async def _get_user_in_directory(
+        self, user_id: str
+    ) -> Optional[Tuple[Optional[str], Optional[str]]]:
+        """
+        Fetch the user information in the user directory.
+
+        Returns:
+            None if the user is unknown, otherwise a tuple of display name and
+            avatar URL (both of which may be None).
+        """
+        return cast(
+            Optional[Tuple[Optional[str], Optional[str]]],
+            await self.db_pool.simple_select_one(
+                table="user_directory",
+                keyvalues={"user_id": user_id},
+                retcols=("display_name", "avatar_url"),
+                allow_none=True,
+                desc="get_user_in_directory",
+            ),
         )
 
     async def update_user_directory_stream_pos(self, stream_id: Optional[int]) -> None:
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 52d708ad17..76a74b8b8f 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -650,8 +650,8 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
 
         next_id = self._load_next_id_txn(txn)
 
-        txn.call_after(self._mark_id_as_finished, next_id)
-        txn.call_on_exception(self._mark_id_as_finished, next_id)
+        txn.call_after(self._mark_ids_as_finished, [next_id])
+        txn.call_on_exception(self._mark_ids_as_finished, [next_id])
         txn.call_after(self._notifier.notify_replication)
 
         # Update the `stream_positions` table with newly updated stream
@@ -671,14 +671,50 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
 
         return self._return_factor * next_id
 
-    def _mark_id_as_finished(self, next_id: int) -> None:
-        """The ID has finished being processed so we should advance the
+    def get_next_mult_txn(self, txn: LoggingTransaction, n: int) -> List[int]:
+        """
+        Usage:
+
+            stream_ids = stream_id_gen.get_next_mult_txn(txn, 5)
+            # ... persist events ...
+        """
+
+        # If we have a list of instances that are allowed to write to this
+        # stream, make sure we're in it.
+        if self._writers and self._instance_name not in self._writers:
+            raise Exception("Tried to allocate stream ID on non-writer")
+
+        next_ids = self._load_next_mult_id_txn(txn, n)
+
+        txn.call_after(self._mark_ids_as_finished, next_ids)
+        txn.call_on_exception(self._mark_ids_as_finished, next_ids)
+        txn.call_after(self._notifier.notify_replication)
+
+        # Update the `stream_positions` table with newly updated stream
+        # ID (unless self._writers is not set in which case we don't
+        # bother, as nothing will read it).
+        #
+        # We only do this on the success path so that the persisted current
+        # position points to a persisted row with the correct instance name.
+        if self._writers:
+            txn.call_after(
+                run_as_background_process,
+                "MultiWriterIdGenerator._update_table",
+                self._db.runInteraction,
+                "MultiWriterIdGenerator._update_table",
+                self._update_stream_positions_table_txn,
+            )
+
+        return [self._return_factor * next_id for next_id in next_ids]
+
+    def _mark_ids_as_finished(self, next_ids: List[int]) -> None:
+        """These IDs have finished being processed so we should advance the
         current position if possible.
         """
 
         with self._lock:
-            self._unfinished_ids.discard(next_id)
-            self._finished_ids.add(next_id)
+            self._unfinished_ids.difference_update(next_ids)
+            self._finished_ids.update(next_ids)
 
             new_cur: Optional[int] = None
 
@@ -727,7 +763,10 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
                     curr, new_cur, self._max_position_of_local_instance
                 )
 
-            self._add_persisted_position(next_id)
+            # TODO: Can we call this for just the last position, or somehow
+            # batch _add_persisted_position?
+            for next_id in next_ids:
+                self._add_persisted_position(next_id)
 
     def get_current_token(self) -> int:
         return self.get_persisted_upto_position()
@@ -933,8 +972,7 @@ class _MultiWriterCtxManager:
         exc: Optional[BaseException],
         tb: Optional[TracebackType],
     ) -> bool:
-        for i in self.stream_ids:
-            self.id_gen._mark_id_as_finished(i)
+        self.id_gen._mark_ids_as_finished(self.stream_ids)
 
         self.notifier.notify_replication()
 
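
`get_next_mult_txn` is the transaction-scoped counterpart of the existing `get_next_mult`: it allocates `n` IDs inside a transaction the caller already holds, and `_mark_ids_as_finished` then advances the position once for the whole batch. A hedged sketch of a caller (the table and helper names here are illustrative, not from this commit):

    from typing import List, Sequence

    def persist_rows(txn, id_gen, keys: Sequence[str]) -> None:
        # One call allocates len(keys) stream IDs; the bookkeeping
        # (finish marks, replication notify) is then paid once per batch.
        stream_ids: List[int] = id_gen.get_next_mult_txn(txn, len(keys))
        for key, stream_id in zip(keys, stream_ids):
            txn.execute(
                "INSERT INTO example_stream (stream_id, key) VALUES (?, ?)",
                (stream_id, key),
            )
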
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index 76c56d5434..15e19b15fb 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -84,7 +84,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
 
         cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type])
 
-        return self.get_success(
+        row = self.get_success(
             self.store.db_pool.simple_select_one(
                 table + "_current",
                 {id_col: stat_id},
@@ -93,6 +93,8 @@ class StatsRoomTests(unittest.HomeserverTestCase):
             )
         )
 
+        return None if row is None else dict(zip(cols, row))
+
     def _perform_background_initial_update(self) -> None:
         # Do the initial population of the stats via the background update
         self._add_background_updates()
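
The test helper keeps its dict-shaped return by rebuilding it from the tuple row; the core of that is just `zip` over the requested columns:

    cols = ["current_state_events", "joined_members"]  # illustrative stats columns
    row = (7, 3)
    assert dict(zip(cols, row)) == {"current_state_events": 7, "joined_members": 3}
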
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index b5f15aa7d4..388447eea6 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -366,7 +366,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         )
         profile = self.get_success(self.store._get_user_in_directory(regular_user_id))
         assert profile is not None
-        self.assertTrue(profile["display_name"] == display_name)
+        self.assertEqual(profile[0], display_name)
 
     def test_handle_local_profile_change_with_deactivated_user(self) -> None:
         # create user
@@ -385,7 +385,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         # profile is in directory
         profile = self.get_success(self.store._get_user_in_directory(r_user_id))
         assert profile is not None
-        self.assertTrue(profile["display_name"] == display_name)
+        self.assertEqual(profile[0], display_name)
 
         # deactivate user
         self.get_success(self.store.set_user_deactivated_status(r_user_id, True))
diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py
index 15f5d644e4..a8e7a76b29 100644
--- a/tests/media/test_media_storage.py
+++ b/tests/media/test_media_storage.py
@@ -504,7 +504,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         origin, media_id = self.media_id.split("/")
         info = self.get_success(self.store.get_cached_remote_media(origin, media_id))
         assert info is not None
-        file_id = info["filesystem_id"]
+        file_id = info.filesystem_id
 
         thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir(
             origin, file_id
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 278808abb5..dac79bd745 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -642,7 +642,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
         assert media_info is not None
-        self.assertFalse(media_info["quarantined_by"])
+        self.assertFalse(media_info.quarantined_by)
 
         # quarantining
         channel = self.make_request(
@@ -656,7 +656,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
         assert media_info is not None
-        self.assertTrue(media_info["quarantined_by"])
+        self.assertTrue(media_info.quarantined_by)
 
         # remove from quarantine
         channel = self.make_request(
@@ -670,7 +670,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
         assert media_info is not None
-        self.assertFalse(media_info["quarantined_by"])
+        self.assertFalse(media_info.quarantined_by)
 
     def test_quarantine_protected_media(self) -> None:
         """
@@ -683,7 +683,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
         # verify protection
         media_info = self.get_success(self.store.get_local_media(self.media_id))
         assert media_info is not None
-        self.assertTrue(media_info["safe_from_quarantine"])
+        self.assertTrue(media_info.safe_from_quarantine)
 
         # quarantining
         channel = self.make_request(
@@ -698,7 +698,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
         # verify that is not in quarantine
         media_info = self.get_success(self.store.get_local_media(self.media_id))
         assert media_info is not None
-        self.assertFalse(media_info["quarantined_by"])
+        self.assertFalse(media_info.quarantined_by)
 
 
 class ProtectMediaByIDTestCase(_AdminMediaTests):
@@ -756,7 +756,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests):
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
         assert media_info is not None
-        self.assertFalse(media_info["safe_from_quarantine"])
+        self.assertFalse(media_info.safe_from_quarantine)
 
         # protect
         channel = self.make_request(
@@ -770,7 +770,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests):
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
         assert media_info is not None
-        self.assertTrue(media_info["safe_from_quarantine"])
+        self.assertTrue(media_info.safe_from_quarantine)
 
         # unprotect
         channel = self.make_request(
@@ -784,7 +784,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests):
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
         assert media_info is not None
-        self.assertFalse(media_info["safe_from_quarantine"])
+        self.assertFalse(media_info.safe_from_quarantine)
 
 
 class PurgeMediaCacheTestCase(_AdminMediaTests):
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 37f37a09d8..42b065d883 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -2706,7 +2706,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # is in user directory
         profile = self.get_success(self.store._get_user_in_directory(self.other_user))
         assert profile is not None
-        self.assertTrue(profile["display_name"] == "User")
+        self.assertEqual(profile[0], "User")
 
         # Deactivate user
         channel = self.make_request(
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index cffbda9a7d..bd59bb50cf 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -139,12 +139,12 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         #
         # Note that we don't have the UI Auth session ID, so just pull out the single
         # row.
-        ui_auth_data = self.get_success(
-            self.store.db_pool.simple_select_one(
-                "ui_auth_sessions", keyvalues={}, retcols=("clientdict",)
+        result = self.get_success(
+            self.store.db_pool.simple_select_one_onecol(
+                "ui_auth_sessions", keyvalues={}, retcol="clientdict"
             )
         )
-        client_dict = db_to_json(ui_auth_data["clientdict"])
+        client_dict = db_to_json(result)
         self.assertNotIn("new_password", client_dict)
 
     @override_config({"rc_3pid_validation": {"burst_count": 3}})
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index ba4e017a0e..b04094b7b3 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -270,15 +270,15 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         self.assertLessEqual(det_data.items(), channel.json_body.items())
 
         # Check the `completed` counter has been incremented and pending is 0
-        res = self.get_success(
+        pending, completed = self.get_success(
             store.db_pool.simple_select_one(
                 "registration_tokens",
                 keyvalues={"token": token},
                 retcols=["pending", "completed"],
             )
         )
-        self.assertEqual(res["completed"], 1)
-        self.assertEqual(res["pending"], 0)
+        self.assertEqual(completed, 1)
+        self.assertEqual(pending, 0)
 
     @override_config({"registration_requires_token": True})
     def test_POST_registration_token_invalid(self) -> None:
@@ -372,15 +372,15 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         params1["auth"]["type"] = LoginType.DUMMY
         self.make_request(b"POST", self.url, params1)
         # Check pending=0 and completed=1
-        res = self.get_success(
+        pending, completed = self.get_success(
             store.db_pool.simple_select_one(
                 "registration_tokens",
                 keyvalues={"token": token},
                 retcols=["pending", "completed"],
             )
         )
-        self.assertEqual(res["pending"], 0)
-        self.assertEqual(res["completed"], 1)
+        self.assertEqual(pending, 0)
+        self.assertEqual(completed, 1)
 
         # Check auth still fails when using token with session2
         channel = self.make_request(b"POST", self.url, params2)
diff --git a/tests/rest/media/test_media_retention.py b/tests/rest/media/test_media_retention.py
index b59d9dfd4d..27a663a23b 100644
--- a/tests/rest/media/test_media_retention.py
+++ b/tests/rest/media/test_media_retention.py
@@ -267,23 +267,23 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
         def _assert_mxc_uri_purge_state(mxc_uri: MXCUri, expect_purged: bool) -> None:
             """Given an MXC URI, assert whether it has been purged or not."""
             if mxc_uri.server_name == self.hs.config.server.server_name:
-                found_media_dict = self.get_success(
-                    self.store.get_local_media(mxc_uri.media_id)
+                found_media = bool(
+                    self.get_success(self.store.get_local_media(mxc_uri.media_id))
                 )
             else:
-                found_media_dict = self.get_success(
-                    self.store.get_cached_remote_media(
-                        mxc_uri.server_name, mxc_uri.media_id
+                found_media = bool(
+                    self.get_success(
+                        self.store.get_cached_remote_media(
+                            mxc_uri.server_name, mxc_uri.media_id
+                        )
                     )
                 )
 
             if expect_purged:
-                self.assertIsNone(
-                    found_media_dict, msg=f"{mxc_uri} unexpectedly not purged"
-                )
+                self.assertFalse(found_media, msg=f"{mxc_uri} unexpectedly not purged")
             else:
-                self.assertIsNotNone(
-                    found_media_dict,
+                self.assertTrue(
+                    found_media,
                     msg=f"{mxc_uri} unexpectedly purged",
                 )
 
diff --git a/tests/storage/databases/main/test_cache.py b/tests/storage/databases/main/test_cache.py
new file mode 100644
index 0000000000..3f71f5d102
--- /dev/null
+++ b/tests/storage/databases/main/test_cache.py
@@ -0,0 +1,117 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import Mock, call
+
+from synapse.storage.database import LoggingTransaction
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.unittest import HomeserverTestCase
+
+
+class CacheInvalidationTestCase(HomeserverTestCase):
+    def setUp(self) -> None:
+        super().setUp()
+        self.store = self.hs.get_datastores().main
+
+    def test_bulk_invalidation(self) -> None:
+        master_invalidate = Mock()
+
+        self.store._get_cached_user_device.invalidate = master_invalidate
+
+        keys_to_invalidate = [
+            ("a", "b"),
+            ("c", "d"),
+            ("e", "f"),
+            ("g", "h"),
+        ]
+
+        def test_txn(txn: LoggingTransaction) -> None:
+            self.store._invalidate_cache_and_stream_bulk(
+                txn,
+                # This is an arbitrarily chosen cached store function. It was chosen
+                # because it takes more than one argument. We check below that the
+                # invalidation was actioned for every key.
+                cache_func=self.store._get_cached_user_device,
+                key_tuples=keys_to_invalidate,
+            )
+
+        self.get_success(
+            self.store.db_pool.runInteraction(
+                "test_invalidate_cache_and_stream_bulk", test_txn
+            )
+        )
+
+        master_invalidate.assert_has_calls(
+            [call(key_list) for key_list in keys_to_invalidate],
+            any_order=True,
+        )
+
+
+class CacheInvalidationOverReplicationTestCase(BaseMultiWorkerStreamTestCase):
+    def setUp(self) -> None:
+        super().setUp()
+        self.store = self.hs.get_datastores().main
+
+    def test_bulk_invalidation_replicates(self) -> None:
+        """Like test_bulk_invalidation, but also checks the invalidations replicate."""
+        master_invalidate = Mock()
+        worker_invalidate = Mock()
+
+        self.store._get_cached_user_device.invalidate = master_invalidate
+        worker = self.make_worker_hs("synapse.app.generic_worker")
+        worker_ds = worker.get_datastores().main
+        worker_ds._get_cached_user_device.invalidate = worker_invalidate
+
+        keys_to_invalidate = [
+            ("a", "b"),
+            ("c", "d"),
+            ("e", "f"),
+            ("g", "h"),
+        ]
+
+        def test_txn(txn: LoggingTransaction) -> None:
+            self.store._invalidate_cache_and_stream_bulk(
+                txn,
+                # This is an arbitrarily chosen cached store function. It was chosen
+                # because it takes more than one argument. We'll use this later to
+                # check that the invalidation was actioned over replication.
+                cache_func=self.store._get_cached_user_device,
+                key_tuples=keys_to_invalidate,
+            )
+
+        assert self.store._cache_id_gen is not None
+        initial_token = self.store._cache_id_gen.get_current_token()
+        self.get_success(
+            self.database_pool.runInteraction(
+                "test_invalidate_cache_and_stream_bulk", test_txn
+            )
+        )
+        second_token = self.store._cache_id_gen.get_current_token()
+
+        self.assertGreaterEqual(second_token, initial_token + len(keys_to_invalidate))
+
+        self.get_success(
+            worker.get_replication_data_handler().wait_for_stream_position(
+                "master", "caches", second_token
+            )
+        )
+
+        master_invalidate.assert_has_calls(
+            [call(key_list) for key_list in keys_to_invalidate],
+            any_order=True,
+        )
+        worker_invalidate.assert_has_calls(
+            [call(key_list) for key_list in keys_to_invalidate],
+            any_order=True,
+        )
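
Both test cases assert with `Mock.assert_has_calls(..., any_order=True)`, which checks that each expected call occurred somewhere in the mock's call history rather than requiring an ordered, contiguous subsequence. In miniature:

    from unittest.mock import Mock, call

    m = Mock()
    m(("c", "d"))
    m(("a", "b"))

    # Passes: both calls happened, order ignored.
    m.assert_has_calls([call(("a", "b")), call(("c", "d"))], any_order=True)
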
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index f34b6b2dcf..491e6d5e63 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -222,7 +222,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
             )
         )
 
-        self.assertEqual({"colA": 1, "colB": 2, "colC": 3}, ret)
+        self.assertEqual((1, 2, 3), ret)
         self.mock_txn.execute.assert_called_once_with(
             "SELECT colA, colB, colC FROM tablename WHERE keycol = ?", ["TheKey"]
         )
@@ -243,7 +243,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
             )
         )
 
-        self.assertFalse(ret)
+        self.assertIsNone(ret)
 
     @defer.inlineCallbacks
     def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]:
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index ce34195a25..d3ffe963d3 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -42,16 +42,9 @@ class RoomStoreTestCase(HomeserverTestCase):
         )
 
     def test_get_room(self) -> None:
-        res = self.get_success(self.store.get_room(self.room.to_string()))
-        assert res is not None
-        self.assertLessEqual(
-            {
-                "room_id": self.room.to_string(),
-                "creator": self.u_creator.to_string(),
-                "is_public": True,
-            }.items(),
-            res.items(),
-        )
+        room = self.get_success(self.store.get_room(self.room.to_string()))
+        assert room is not None
+        self.assertTrue(room[0])
 
     def test_get_room_unknown_room(self) -> None:
         self.assertIsNone(self.get_success(self.store.get_room("!uknown:test")))
diff --git a/tests/utils.py b/tests/utils.py
index 9be02b8ea7..c44e5cb4ee 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -83,11 +83,11 @@ def setupdb() -> None:
 
         # Set up in the db
         db_conn = db_engine.module.connect(
+            dbname=POSTGRES_BASE_DB,
             user=POSTGRES_USER,
             host=POSTGRES_HOST,
             port=POSTGRES_PORT,
             password=POSTGRES_PASSWORD,
-            dbname=POSTGRES_BASE_DB,
         )
         logging_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
         prepare_database(logging_conn, db_engine, None)