Diffstat (limited to 'synapse')
-rwxr-xr-x  synapse/_scripts/synapse_port_db.py | 3
-rw-r--r--  synapse/api/errors.py | 2
-rw-r--r--  synapse/app/generic_worker.py | 4
-rw-r--r--  synapse/config/ratelimiting.py | 7
-rw-r--r--  synapse/config/repository.py | 6
-rw-r--r--  synapse/federation/federation_server.py | 2
-rw-r--r--  synapse/federation/sender/__init__.py | 4
-rw-r--r--  synapse/handlers/admin.py | 2
-rw-r--r--  synapse/handlers/device.py | 10
-rw-r--r--  synapse/handlers/e2e_keys.py | 20
-rw-r--r--  synapse/handlers/federation_event.py | 20
-rw-r--r--  synapse/handlers/presence.py | 2
-rw-r--r--  synapse/handlers/profile.py | 15
-rw-r--r--  synapse/handlers/room.py | 6
-rw-r--r--  synapse/handlers/room_list.py | 43
-rw-r--r--  synapse/handlers/room_member.py | 3
-rw-r--r--  synapse/handlers/room_summary.py | 26
-rw-r--r--  synapse/handlers/sso.py | 2
-rw-r--r--  synapse/handlers/sync.py | 4
-rw-r--r--  synapse/handlers/user_directory.py | 8
-rw-r--r--  synapse/http/matrixfederationclient.py | 14
-rw-r--r--  synapse/logging/opentracing.py | 7
-rw-r--r--  synapse/media/_base.py | 6
-rw-r--r--  synapse/media/media_repository.py | 284
-rw-r--r--  synapse/media/url_previewer.py | 11
-rw-r--r--  synapse/metrics/_reactor_metrics.py | 130
-rw-r--r--  synapse/module_api/__init__.py | 3
-rw-r--r--  synapse/module_api/callbacks/third_party_event_rules_callbacks.py | 3
-rw-r--r--  synapse/push/bulk_push_rule_evaluator.py | 56
-rw-r--r--  synapse/replication/tcp/handler.py | 72
-rw-r--r--  synapse/replication/tcp/redis.py | 2
-rw-r--r--  synapse/replication/tcp/resource.py | 17
-rw-r--r--  synapse/replication/tcp/streams/_base.py | 18
-rw-r--r--  synapse/rest/admin/__init__.py | 2
-rw-r--r--  synapse/rest/admin/media.py | 6
-rw-r--r--  synapse/rest/admin/registration_tokens.py | 13
-rw-r--r--  synapse/rest/admin/rooms.py | 19
-rw-r--r--  synapse/rest/admin/users.py | 50
-rw-r--r--  synapse/rest/client/account.py | 19
-rw-r--r--  synapse/rest/client/directory.py | 2
-rw-r--r--  synapse/rest/client/keys.py | 16
-rw-r--r--  synapse/rest/media/create_resource.py | 83
-rw-r--r--  synapse/rest/media/download_resource.py | 22
-rw-r--r--  synapse/rest/media/media_repository_resource.py | 8
-rw-r--r--  synapse/rest/media/thumbnail_resource.py | 81
-rw-r--r--  synapse/rest/media/upload_resource.py | 75
-rw-r--r--  synapse/storage/background_updates.py | 32
-rw-r--r--  synapse/storage/controllers/persist_events.py | 252
-rw-r--r--  synapse/storage/database.py | 77
-rw-r--r--  synapse/storage/databases/__init__.py | 2
-rw-r--r--  synapse/storage/databases/main/__init__.py | 52
-rw-r--r--  synapse/storage/databases/main/account_data.py | 24
-rw-r--r--  synapse/storage/databases/main/cache.py | 75
-rw-r--r--  synapse/storage/databases/main/deviceinbox.py | 106
-rw-r--r--  synapse/storage/databases/main/devices.py | 100
-rw-r--r--  synapse/storage/databases/main/e2e_room_keys.py | 31
-rw-r--r--  synapse/storage/databases/main/end_to_end_keys.py | 114
-rw-r--r--  synapse/storage/databases/main/event_federation.py | 24
-rw-r--r--  synapse/storage/databases/main/events.py | 415
-rw-r--r--  synapse/storage/databases/main/events_bg_updates.py | 19
-rw-r--r--  synapse/storage/databases/main/events_worker.py | 8
-rw-r--r--  synapse/storage/databases/main/keys.py | 17
-rw-r--r--  synapse/storage/databases/main/media_repository.py | 249
-rw-r--r--  synapse/storage/databases/main/monthly_active_users.py | 2
-rw-r--r--  synapse/storage/databases/main/presence.py | 9
-rw-r--r--  synapse/storage/databases/main/profile.py | 28
-rw-r--r--  synapse/storage/databases/main/purge_events.py | 33
-rw-r--r--  synapse/storage/databases/main/push_rule.py | 94
-rw-r--r--  synapse/storage/databases/main/receipts.py | 4
-rw-r--r--  synapse/storage/databases/main/registration.py | 149
-rw-r--r--  synapse/storage/databases/main/room.py | 129
-rw-r--r--  synapse/storage/databases/main/roommember.py | 19
-rw-r--r--  synapse/storage/databases/main/search.py | 12
-rw-r--r--  synapse/storage/databases/main/stream.py | 30
-rw-r--r--  synapse/storage/databases/main/task_scheduler.py | 48
-rw-r--r--  synapse/storage/databases/main/transactions.py | 20
-rw-r--r--  synapse/storage/databases/main/ui_auth.py | 31
-rw-r--r--  synapse/storage/databases/main/user_directory.py | 27
-rw-r--r--  synapse/storage/databases/state/bg_updates.py | 4
-rw-r--r--  synapse/storage/engines/postgres.py | 3
-rw-r--r--  synapse/storage/schema/__init__.py | 3
-rw-r--r--  synapse/storage/schema/main/delta/54/delete_forward_extremities.sql | 2
-rw-r--r--  synapse/storage/schema/main/delta/56/remove_tombstoned_rooms_from_directory.sql | 3
-rw-r--r--  synapse/storage/schema/main/delta/70/01clean_table_purged_rooms.sql | 3
-rw-r--r--  synapse/storage/schema/main/delta/83/05_cross_signing_key_update_grant.sql | 15
-rw-r--r--  synapse/storage/util/id_generators.py | 56
-rw-r--r--  synapse/util/__init__.py | 2
-rw-r--r--  synapse/util/async_helpers.py | 14
-rw-r--r--  synapse/util/check_dependencies.py | 3
-rw-r--r--  synapse/util/iterutils.py | 51
-rw-r--r--  synapse/util/task_scheduler.py | 4
91 files changed, 2449 insertions, 1154 deletions
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index ef8590db65..75fe0183f6 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -348,8 +348,7 @@ class Porter:
                     backward_chunk = 0
                     already_ported = 0
             else:
-                forward_chunk = row["forward_rowid"]
-                backward_chunk = row["backward_rowid"]
+                forward_chunk, backward_chunk = row
 
             if total_to_port is None:
                 already_ported, total_to_port = await self._get_total_count_to_port(
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index fdb2955be8..fbd8b16ec3 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -83,6 +83,8 @@ class Codes(str, Enum):
     USER_DEACTIVATED = "M_USER_DEACTIVATED"
     # USER_LOCKED = "M_USER_LOCKED"
     USER_LOCKED = "ORG_MATRIX_MSC3939_USER_LOCKED"
+    NOT_YET_UPLOADED = "M_NOT_YET_UPLOADED"
+    CANNOT_OVERWRITE_MEDIA = "M_CANNOT_OVERWRITE_MEDIA"
 
     # Part of MSC3848
     # https://github.com/matrix-org/matrix-spec-proposals/pull/3848
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index f7c80eee21..bcfb7a7200 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -104,8 +104,8 @@ logger = logging.getLogger("synapse.app.generic_worker")
 
 
 class GenericWorkerStore(
-    # FIXME(#3714): We need to add UserDirectoryStore as we write directly
-    # rather than going via the correct worker.
+    # FIXME(https://github.com/matrix-org/synapse/issues/3714): We need to add
+    # UserDirectoryStore as we write directly rather than going via the correct worker.
     UserDirectoryStore,
     StatsStore,
     UIAuthWorkerStore,
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 4efbaeac0d..b1fcaf71a3 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -204,3 +204,10 @@ class RatelimitConfig(Config):
             "rc_third_party_invite",
             defaults={"per_second": 0.0025, "burst_count": 5},
         )
+
+        # Ratelimit create media requests:
+        self.rc_media_create = RatelimitSettings.parse(
+            config,
+            "rc_media_create",
+            defaults={"per_second": 10, "burst_count": 50},
+        )
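
The new rc_media_create setting uses the same two knobs as Synapse's other ratelimits: burst_count requests may be made at once, after which requests refill at per_second. A minimal token-bucket sketch of those semantics (an illustration only, not Synapse's Ratelimiter class):

    import time

    class TokenBucket:
        """Allow `burst_count` requests at once, refilling at `per_second`."""

        def __init__(self, per_second: float, burst_count: int) -> None:
            self.per_second = per_second
            self.burst_count = burst_count
            self.tokens = float(burst_count)
            self.last = time.monotonic()

        def allow(self) -> bool:
            now = time.monotonic()
            # Refill in proportion to elapsed time, capped at the burst size.
            self.tokens = min(
                self.burst_count, self.tokens + (now - self.last) * self.per_second
            )
            self.last = now
            if self.tokens >= 1:
                self.tokens -= 1
                return True
            return False

    # With the defaults above: up to 50 media-create requests immediately,
    # then a sustained 10 per second.
    bucket = TokenBucket(per_second=10, burst_count=50)
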
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index f6cfdd3e04..839c026d70 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -141,6 +141,12 @@ class ContentRepositoryConfig(Config):
             "prevent_media_downloads_from", []
         )
 
+        self.unused_expiration_time = self.parse_duration(
+            config.get("unused_expiration_time", "24h")
+        )
+
+        self.max_pending_media_uploads = config.get("max_pending_media_uploads", 5)
+
         self.media_store_path = self.ensure_directory(
             config.get("media_store_path", "media_store")
         )
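
unused_expiration_time accepts a duration string; Synapse's Config.parse_duration normalises values like "24h" to milliseconds. A minimal re-implementation of that normalisation, just to pin down the units the media code below works in (the real parser supports more suffixes and formats):

    from typing import Union

    # Assumed unit table; Synapse's parse_duration also accepts other suffixes.
    UNITS_MS = {"s": 1_000, "m": 60_000, "h": 3_600_000, "d": 86_400_000}

    def parse_duration_ms(value: Union[str, int]) -> int:
        """Return a duration in milliseconds from an int (already ms) or a
        suffixed string such as "24h"."""
        if isinstance(value, int):
            return value
        return int(value[:-1]) * UNITS_MS[value[-1]]

    assert parse_duration_ms("24h") == 86_400_000  # the default used above
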
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 8e3064c7e7..2bb2c64ebe 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -84,7 +84,7 @@ from synapse.replication.http.federation import (
 from synapse.storage.databases.main.lock import Lock
 from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
 from synapse.storage.roommember import MemberSummary
-from synapse.types import JsonDict, StateMap, get_domain_from_id, UserID
+from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
 from synapse.util.caches.response_cache import ResponseCache
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 7980d1a322..948fde6658 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -581,14 +581,14 @@ class FederationSender(AbstractFederationSender):
                                 "get_joined_hosts", str(sg)
                             )
                             if destinations is None:
-                                # Add logging to help track down #13444
+                                # Add logging to help track down https://github.com/matrix-org/synapse/issues/13444
                                 logger.info(
                                     "Unexpectedly did not have cached destinations for %s / %s",
                                     sg,
                                     event.event_id,
                                 )
                         else:
-                            # Add logging to help track down #13444
+                            # Add logging to help track down https://github.com/matrix-org/synapse/issues/13444
                             logger.info(
                                 "Unexpectedly did not have cached prev group for %s",
                                 event.event_id,
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 2c2baeac67..d06f8e3296 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -283,7 +283,7 @@ class AdminHandler:
                 start, limit, user_id
             )
             for media in media_ids:
-                writer.write_media_id(media["media_id"], media)
+                writer.write_media_id(media.media_id, attr.asdict(media))
 
             logger.info(
                 "[%s] Written %d media_ids of %s",
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 93472d0117..98e6e42563 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -383,7 +383,7 @@ class DeviceWorkerHandler:
         )
 
     DEVICE_MSGS_DELETE_BATCH_LIMIT = 1000
-    DEVICE_MSGS_DELETE_SLEEP_MS = 1000
+    DEVICE_MSGS_DELETE_SLEEP_MS = 100
 
     async def _delete_device_messages(
         self,
@@ -396,15 +396,17 @@ class DeviceWorkerHandler:
         up_to_stream_id = task.params["up_to_stream_id"]
 
         # Delete the messages in batches to avoid too much DB load.
+        from_stream_id = None
         while True:
-            res = await self.store.delete_messages_for_device(
+            from_stream_id, _ = await self.store.delete_messages_for_device_between(
                 user_id=user_id,
                 device_id=device_id,
-                up_to_stream_id=up_to_stream_id,
+                from_stream_id=from_stream_id,
+                to_stream_id=up_to_stream_id,
                 limit=DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT,
             )
 
-            if res < DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT:
+            if from_stream_id is None:
                 return TaskStatus.COMPLETE, None, None
 
             await self.clock.sleep(DeviceHandler.DEVICE_MSGS_DELETE_SLEEP_MS / 1000.0)
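
The loop above now resumes from a cursor instead of repeatedly deleting up to up_to_stream_id from scratch, and sleeps 100ms (down from 1s) between batches. The store method lives in deviceinbox.py (changed in this commit but not shown here); a pure-Python sketch of the contract assumed by the call site, where the first element of the returned tuple is the cursor to resume from, or None once the range is exhausted:

    from typing import List, Optional, Tuple

    def delete_messages_between_sketch(
        stream_ids: List[int],
        from_stream_id: Optional[int],
        to_stream_id: int,
        limit: int,
    ) -> Tuple[Optional[int], int]:
        """Delete at most `limit` messages with from_stream_id < id <= to_stream_id,
        returning (resume_cursor_or_None, number_deleted). Assumed contract only."""
        low = from_stream_id if from_stream_id is not None else -1
        batch = sorted(s for s in stream_ids if low < s <= to_stream_id)[:limit]
        for s in batch:
            stream_ids.remove(s)  # stand-in for the DELETE
        if len(batch) < limit:
            return None, len(batch)  # nothing left in range: caller stops looping
        return batch[-1], len(batch)

    msgs = list(range(1, 2501))
    cursor: Optional[int] = None
    while True:
        cursor, _ = delete_messages_between_sketch(msgs, cursor, to_stream_id=2000, limit=1000)
        if cursor is None:
            break
    assert msgs == list(range(2001, 2501))
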
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index d06524495f..70fa931d17 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -1450,19 +1450,25 @@ class E2eKeysHandler:
 
         return desired_key_data
 
-    async def is_cross_signing_set_up_for_user(self, user_id: str) -> bool:
+    async def check_cross_signing_setup(self, user_id: str) -> Tuple[bool, bool]:
         """Checks if the user has cross-signing set up
 
         Args:
             user_id: The user to check
 
-        Returns:
-            True if the user has cross-signing set up, False otherwise
+        Returns: a 2-tuple of booleans
+            - whether the user has cross-signing set up, and
+            - whether the user's master cross-signing key may be replaced without UIA.
         """
-        existing_master_key = await self.store.get_e2e_cross_signing_key(
-            user_id, "master"
-        )
-        return existing_master_key is not None
+        (
+            exists,
+            ts_replacable_without_uia_before,
+        ) = await self.store.get_master_cross_signing_key_updatable_before(user_id)
+
+        if ts_replacable_without_uia_before is None:
+            return exists, False
+        else:
+            return exists, self.clock.time_msec() < ts_replacable_without_uia_before
 
 
 def _check_cross_signing_key(
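
check_cross_signing_setup now also reports whether the master key may currently be replaced without user-interactive auth (UIA), based on a per-user timestamp presumably added by the 05_cross_signing_key_update_grant.sql delta listed in the diffstat. A small illustration of the time comparison; the ten-minute grant window below is made up for the example:

    import time
    from typing import Optional

    def replaceable_without_uia(exists: bool, updatable_before_ms: Optional[int]) -> bool:
        """Mirror of the check above: replacement is allowed only while the
        current time is still before the stored expiry timestamp."""
        if not exists or updatable_before_ms is None:
            return False
        return int(time.time() * 1000) < updatable_before_ms

    # e.g. an admin grants a 10-minute window in which the user may replace
    # their master cross-signing key without re-authenticating:
    grant_expiry_ms = int(time.time() * 1000) + 10 * 60 * 1000
    assert replaceable_without_uia(True, grant_expiry_ms)
    assert not replaceable_without_uia(True, None)
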
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 0cc8e990d9..f4c17894aa 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -88,7 +88,7 @@ from synapse.types import (
 )
 from synapse.types.state import StateFilter
 from synapse.util.async_helpers import Linearizer, concurrently_execute
-from synapse.util.iterutils import batch_iter, partition
+from synapse.util.iterutils import batch_iter, partition, sorted_topologically_batched
 from synapse.util.retryutils import NotRetryingDestination
 from synapse.util.stringutils import shortstr
 
@@ -748,7 +748,7 @@ class FederationEventHandler:
         # fetching fresh state for the room if the missing event
         # can't be found, which slightly reduces our security.
         # it may also increase our DAG extremity count for the room,
-        # causing additional state resolution?  See #1760.
+        # causing additional state resolution?  See https://github.com/matrix-org/synapse/issues/1760.
         # However, fetching state doesn't hold the linearizer lock
         # apparently.
         #
@@ -1669,14 +1669,13 @@ class FederationEventHandler:
 
         # XXX: it might be possible to kick this process off in parallel with fetching
         # the events.
-        while event_map:
-            # build a list of events whose auth events are not in the queue.
-            roots = tuple(
-                ev
-                for ev in event_map.values()
-                if not any(aid in event_map for aid in ev.auth_event_ids())
-            )
 
+        # We need to persist an event's auth events before the event.
+        auth_graph = {
+            ev: [event_map[e_id] for e_id in ev.auth_event_ids() if e_id in event_map]
+            for ev in event_map.values()
+        }
+        for roots in sorted_topologically_batched(event_map.values(), auth_graph):
             if not roots:
                 # if *none* of the remaining events are ready, that means
                 # we have a loop. This either means a bug in our logic, or that
@@ -1698,9 +1697,6 @@ class FederationEventHandler:
 
             await self._auth_and_persist_outliers_inner(room_id, roots)
 
-            for ev in roots:
-                del event_map[ev.event_id]
-
     async def _auth_and_persist_outliers_inner(
         self, room_id: str, fetched_events: Collection[EventBase]
     ) -> None:
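
The hand-rolled "find events whose auth events are all persisted" loop is replaced by sorted_topologically_batched from synapse/util/iterutils.py (changed in this commit but not shown here). A minimal sketch of that kind of layered, Kahn-style topological sort; the real helper's signature and cycle handling may differ:

    from typing import Dict, Hashable, Iterable, Iterator, List

    def sorted_topologically_batched_sketch(
        nodes: Iterable[Hashable], graph: Dict[Hashable, List[Hashable]]
    ) -> Iterator[List[Hashable]]:
        """Yield batches such that every node's dependencies (graph[node]) appear
        in an earlier batch. Dependencies outside `nodes` are ignored, matching
        the `if e_id in event_map` filter above."""
        nodes = list(nodes)
        pending = set(nodes)
        done: set = set()
        while pending:
            batch = [
                n
                for n in nodes
                if n in pending
                and all(d in done or d not in pending for d in graph.get(n, []))
            ]
            if not batch:
                raise ValueError("dependency cycle detected")
            yield batch
            done.update(batch)
            pending.difference_update(batch)

    # Auth events ("A") are persisted before the events that cite them:
    deps = {"A": [], "B": ["A"], "C": ["A", "B"], "D": []}
    assert list(sorted_topologically_batched_sketch("ABCD", deps)) == [["A", "D"], ["B"], ["C"]]
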
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 202beee738..4137fd50b1 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -1816,7 +1816,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
                 # the same token repeatedly.
                 #
                 # Hence this guard where we just return nothing so that the sync
-                # doesn't return. C.f. #5503.
+                # doesn't return. C.f. https://github.com/matrix-org/synapse/issues/5503.
                 return [], max_token
 
             # Figure out which other users this user should explicitly receive
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index c2109036ec..1027fbfd28 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import logging
 import random
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Union
 
 from synapse.api.errors import (
     AuthError,
@@ -23,6 +23,7 @@ from synapse.api.errors import (
     StoreError,
     SynapseError,
 )
+from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia
 from synapse.types import JsonDict, Requester, UserID, create_requester
 from synapse.util.caches.descriptors import cached
 from synapse.util.stringutils import parse_and_validate_mxc_uri
@@ -306,7 +307,9 @@ class ProfileHandler:
             server_name = host
 
         if self._is_mine_server_name(server_name):
-            media_info = await self.store.get_local_media(media_id)
+            media_info: Optional[
+                Union[LocalMedia, RemoteMedia]
+            ] = await self.store.get_local_media(media_id)
         else:
             media_info = await self.store.get_cached_remote_media(server_name, media_id)
 
@@ -322,12 +325,12 @@ class ProfileHandler:
 
         if self.max_avatar_size:
             # Ensure avatar does not exceed max allowed avatar size
-            if media_info["media_length"] > self.max_avatar_size:
+            if media_info.media_length > self.max_avatar_size:
                 logger.warning(
                     "Forbidding avatar change to %s: %d bytes is above the allowed size "
                     "limit",
                     mxc,
-                    media_info["media_length"],
+                    media_info.media_length,
                 )
                 return False
 
@@ -335,12 +338,12 @@ class ProfileHandler:
             # Ensure the avatar's file type is allowed
             if (
                 self.allowed_avatar_mimetypes
-                and media_info["media_type"] not in self.allowed_avatar_mimetypes
+                and media_info.media_type not in self.allowed_avatar_mimetypes
             ):
                 logger.warning(
                     "Forbidding avatar change to %s: mimetype %s not allowed",
                     mxc,
-                    media_info["media_type"],
+                    media_info.media_type,
                 )
                 return False
 
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 6d680b0795..afd8138caf 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -269,7 +269,7 @@ class RoomCreationHandler:
         self,
         requester: Requester,
         old_room_id: str,
-        old_room: Dict[str, Any],
+        old_room: Tuple[bool, str, bool],
         new_room_id: str,
         new_version: RoomVersion,
         tombstone_event: EventBase,
@@ -279,7 +279,7 @@ class RoomCreationHandler:
         Args:
             requester: the user requesting the upgrade
             old_room_id: the id of the room to be replaced
-            old_room: a dict containing room information for the room to be replaced,
+            old_room: a tuple containing room information for the room to be replaced,
                 as returned by `RoomWorkerStore.get_room`.
             new_room_id: the id of the replacement room
             new_version: the version to upgrade the room to
@@ -299,7 +299,7 @@ class RoomCreationHandler:
         await self.store.store_room(
             room_id=new_room_id,
             room_creator_user_id=user_id,
-            is_public=old_room["is_public"],
+            is_public=old_room[0],
             room_version=new_version,
         )
 
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 36e2db8975..2947e154be 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -33,6 +33,7 @@ from synapse.api.errors import (
     RequestSendFailed,
     SynapseError,
 )
+from synapse.storage.databases.main.room import LargestRoomStats
 from synapse.types import JsonDict, JsonMapping, ThirdPartyInstanceID
 from synapse.util.caches.descriptors import _CacheContext, cached
 from synapse.util.caches.response_cache import ResponseCache
@@ -170,26 +171,24 @@ class RoomListHandler:
             ignore_non_federatable=from_federation,
         )
 
-        def build_room_entry(room: JsonDict) -> JsonDict:
+        def build_room_entry(room: LargestRoomStats) -> JsonDict:
             entry = {
-                "room_id": room["room_id"],
-                "name": room["name"],
-                "topic": room["topic"],
-                "canonical_alias": room["canonical_alias"],
-                "num_joined_members": room["joined_members"],
-                "avatar_url": room["avatar"],
-                "world_readable": room["history_visibility"]
+                "room_id": room.room_id,
+                "name": room.name,
+                "topic": room.topic,
+                "canonical_alias": room.canonical_alias,
+                "num_joined_members": room.joined_members,
+                "avatar_url": room.avatar,
+                "world_readable": room.history_visibility
                 == HistoryVisibility.WORLD_READABLE,
-                "guest_can_join": room["guest_access"] == "can_join",
-                "join_rule": room["join_rules"],
-                "room_type": room["room_type"],
+                "guest_can_join": room.guest_access == "can_join",
+                "join_rule": room.join_rules,
+                "room_type": room.room_type,
             }
 
             # Filter out Nones – rather omit the field altogether
             return {k: v for k, v in entry.items() if v is not None}
 
-        results = [build_room_entry(r) for r in results]
-
         response: JsonDict = {}
         num_results = len(results)
         if limit is not None:
@@ -212,33 +211,33 @@ class RoomListHandler:
                     # If there was a token given then we assume that there
                     # must be previous results.
                     response["prev_batch"] = RoomListNextBatch(
-                        last_joined_members=initial_entry["num_joined_members"],
-                        last_room_id=initial_entry["room_id"],
+                        last_joined_members=initial_entry.joined_members,
+                        last_room_id=initial_entry.room_id,
                         direction_is_forward=False,
                     ).to_token()
 
                 if more_to_come:
                     response["next_batch"] = RoomListNextBatch(
-                        last_joined_members=final_entry["num_joined_members"],
-                        last_room_id=final_entry["room_id"],
+                        last_joined_members=final_entry.joined_members,
+                        last_room_id=final_entry.room_id,
                         direction_is_forward=True,
                     ).to_token()
             else:
                 if has_batch_token:
                     response["next_batch"] = RoomListNextBatch(
-                        last_joined_members=final_entry["num_joined_members"],
-                        last_room_id=final_entry["room_id"],
+                        last_joined_members=final_entry.joined_members,
+                        last_room_id=final_entry.room_id,
                         direction_is_forward=True,
                     ).to_token()
 
                 if more_to_come:
                     response["prev_batch"] = RoomListNextBatch(
-                        last_joined_members=initial_entry["num_joined_members"],
-                        last_room_id=initial_entry["room_id"],
+                        last_joined_members=initial_entry.joined_members,
+                        last_room_id=initial_entry.room_id,
                         direction_is_forward=False,
                     ).to_token()
 
-        response["chunk"] = results
+        response["chunk"] = [build_room_entry(r) for r in results]
 
         response["total_room_count_estimate"] = await self.store.count_public_rooms(
             network_tuple,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 918eb203e2..eddc2af9ba 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1260,7 +1260,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         # Add new room to the room directory if the old room was there
         # Remove old room from the room directory
         old_room = await self.store.get_room(old_room_id)
-        if old_room is not None and old_room["is_public"]:
+        # If the old room exists and is public.
+        if old_room is not None and old_room[0]:
             await self.store.set_room_is_public(old_room_id, False)
             await self.store.set_room_is_public(room_id, True)
 
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index dd559b4c45..1dfb12e065 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -703,24 +703,24 @@ class RoomSummaryHandler:
         # there should always be an entry
         assert stats is not None, "unable to retrieve stats for %s" % (room_id,)
 
-        entry = {
-            "room_id": stats["room_id"],
-            "name": stats["name"],
-            "topic": stats["topic"],
-            "canonical_alias": stats["canonical_alias"],
-            "num_joined_members": stats["joined_members"],
-            "avatar_url": stats["avatar"],
-            "join_rule": stats["join_rules"],
+        entry: JsonDict = {
+            "room_id": stats.room_id,
+            "name": stats.name,
+            "topic": stats.topic,
+            "canonical_alias": stats.canonical_alias,
+            "num_joined_members": stats.joined_members,
+            "avatar_url": stats.avatar,
+            "join_rule": stats.join_rules,
             "world_readable": (
-                stats["history_visibility"] == HistoryVisibility.WORLD_READABLE
+                stats.history_visibility == HistoryVisibility.WORLD_READABLE
             ),
-            "guest_can_join": stats["guest_access"] == "can_join",
-            "room_type": stats["room_type"],
+            "guest_can_join": stats.guest_access == "can_join",
+            "room_type": stats.room_type,
         }
 
         if self._msc3266_enabled:
-            entry["im.nheko.summary.version"] = stats["version"]
-            entry["im.nheko.summary.encryption"] = stats["encryption"]
+            entry["im.nheko.summary.version"] = stats.version
+            entry["im.nheko.summary.encryption"] = stats.encryption
 
         # Federation requests need to provide additional information so the
         # requested server is able to filter the response appropriately.
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 62f2454f5d..389dc5298a 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -806,7 +806,7 @@ class SsoHandler:
                 media_id = profile["avatar_url"].split("/")[-1]
                 if self._is_mine_server_name(server_name):
                     media = await self._media_repo.store.get_local_media(media_id)
-                    if media is not None and upload_name == media["upload_name"]:
+                    if media is not None and upload_name == media.upload_name:
                         logger.info("skipping saving the user avatar")
                         return True
 
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 2f1bc5a015..bf0106c6e7 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -399,7 +399,7 @@ class SyncHandler:
         #
         # If that happens, we mustn't cache it, so that when the client comes back
         # with the same cache token, we don't immediately return the same empty
-        # result, causing a tightloop. (#8518)
+        # result, causing a tightloop. (https://github.com/matrix-org/synapse/issues/8518)
         if result.next_batch == since_token:
             cache_context.should_cache = False
 
@@ -1003,7 +1003,7 @@ class SyncHandler:
                     # always make sure we LL ourselves so we know we're in the room
                     # (if we are) to fix https://github.com/vector-im/riot-web/issues/7209
                     # We only need apply this on full state syncs given we disabled
-                    # LL for incr syncs in #3840.
+                    # LL for incr syncs in https://github.com/matrix-org/synapse/pull/3840.
                     # We don't insert ourselves into `members_to_fetch`, because in some
                     # rare cases (an empty event batch with a now_token after the user's
                     # leave in a partial state room which another local user has
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 75717ba4f9..3c19ea56f8 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -184,8 +184,8 @@ class UserDirectoryHandler(StateDeltasHandler):
         """Called to update index of our local user profiles when they change
         irrespective of any rooms the user may be in.
         """
-        # FIXME(#3714): We should probably do this in the same worker as all
-        # the other changes.
+        # FIXME(https://github.com/matrix-org/synapse/issues/3714): We should
+        # probably do this in the same worker as all the other changes.
 
         if await self.store.should_include_local_user_in_dir(user_id):
             await self.store.update_profile_in_user_dir(
@@ -194,8 +194,8 @@ class UserDirectoryHandler(StateDeltasHandler):
 
     async def handle_local_user_deactivated(self, user_id: str) -> None:
         """Called when a user ID is deactivated"""
-        # FIXME(#3714): We should probably do this in the same worker as all
-        # the other changes.
+        # FIXME(https://github.com/matrix-org/synapse/issues/3714): We should
+        # probably do this in the same worker as all the other changes.
         await self.store.remove_from_user_dir(user_id)
 
     async def _unsafe_process(self) -> None:
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 08c7fc1631..d5013e8e97 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -465,7 +465,7 @@ class MatrixFederationHttpClient:
         """Wrapper for _send_request which can optionally retry the request
         upon receiving a combination of a 400 HTTP response code and a
         'M_UNRECOGNIZED' errcode. This is a workaround for Synapse <= v0.99.3
-        due to #3622.
+        due to https://github.com/matrix-org/synapse/issues/3622.
 
         Args:
             request: details of request to be sent
@@ -958,9 +958,9 @@ class MatrixFederationHttpClient:
                 requests).
             try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
                 response we should try appending a trailing slash to the end
-                of the request. Workaround for #3622 in Synapse <= v0.99.3. This
-                will be attempted before backing off if backing off has been
-                enabled.
+                of the request. Workaround for https://github.com/matrix-org/synapse/issues/3622
+                in Synapse <= v0.99.3. This will be attempted before backing off if
+                backing off has been enabled.
             parser: The parser to use to decode the response. Defaults to
                 parsing as JSON.
             backoff_on_all_error_codes: Back off if we get any error response
@@ -1155,7 +1155,8 @@ class MatrixFederationHttpClient:
 
             try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
                 response we should try appending a trailing slash to the end of
-                the request. Workaround for #3622 in Synapse <= v0.99.3.
+                the request. Workaround for https://github.com/matrix-org/synapse/issues/3622
+                in Synapse <= v0.99.3.
 
             parser: The parser to use to decode the response. Defaults to
                 parsing as JSON.
@@ -1250,7 +1251,8 @@ class MatrixFederationHttpClient:
 
             try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
                 response we should try appending a trailing slash to the end of
-                the request. Workaround for #3622 in Synapse <= v0.99.3.
+                the request. Workaround for https://github.com/matrix-org/synapse/issues/3622
+                in Synapse <= v0.99.3.
 
             parser: The parser to use to decode the response. Defaults to
                 parsing as JSON.
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 4454fe29a5..e297fa9c8b 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -1019,11 +1019,14 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]:
     if not opentracing:
         return func
 
+    # getfullargspec is somewhat expensive, so ensure it is only called a single
+    # time (the function signature shouldn't change anyway).
+    argspec = inspect.getfullargspec(func)
+
     @contextlib.contextmanager
     def _wrapping_logic(
-        func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
+        _func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
     ) -> Generator[None, None, None]:
-        argspec = inspect.getfullargspec(func)
         # We use `[1:]` to skip the `self` object reference and `start=1` to
         # make the index line up with `argspec.args`.
         #
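
Computing inspect.getfullargspec once at decoration time rather than on every call is a general pattern: anything that depends only on the function itself can be hoisted out of the per-call wrapper. A standalone sketch of the same idea, unrelated to the tracing code itself:

    import functools
    import inspect
    from typing import Any, Callable, TypeVar

    F = TypeVar("F", bound=Callable[..., Any])

    def log_args(func: F) -> F:
        # Computed once per decorated function, not on every call.
        argspec = inspect.getfullargspec(func)

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            for name, value in zip(argspec.args, args):
                print(f"{func.__name__}: {name}={value!r}")
            return func(*args, **kwargs)

        return wrapper  # type: ignore[return-value]

    @log_args
    def connect(host: str, port: int) -> str:
        return f"{host}:{port}"

    connect("example.org", 8448)  # prints host/port before running
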
diff --git a/synapse/media/_base.py b/synapse/media/_base.py
index 860e5ddca2..9d88a711cf 100644
--- a/synapse/media/_base.py
+++ b/synapse/media/_base.py
@@ -83,6 +83,12 @@ INLINE_CONTENT_TYPES = [
     "audio/x-flac",
 ]
 
+# Default timeout_ms for download and thumbnail requests
+DEFAULT_MAX_TIMEOUT_MS = 20_000
+
+# Maximum allowed timeout_ms for download and thumbnail requests
+MAXIMUM_ALLOWED_MAX_TIMEOUT_MS = 60_000
+
 
 def respond_404(request: SynapseRequest) -> None:
     assert request.path is not None
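
These two constants bound the timeout_ms query parameter presumably accepted by the download and thumbnail servlets (changed in this commit but not shown here). A hedged sketch of that clamping; the actual parameter handling in those servlets may differ:

    from typing import Optional

    DEFAULT_MAX_TIMEOUT_MS = 20_000
    MAXIMUM_ALLOWED_MAX_TIMEOUT_MS = 60_000

    def clamp_timeout_ms(raw: Optional[str]) -> int:
        """Parse an optional ?timeout_ms value, defaulting and capping it."""
        try:
            requested = int(raw) if raw is not None else DEFAULT_MAX_TIMEOUT_MS
        except ValueError:
            requested = DEFAULT_MAX_TIMEOUT_MS
        return max(0, min(requested, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS))

    assert clamp_timeout_ms(None) == 20_000
    assert clamp_timeout_ms("120000") == 60_000
    assert clamp_timeout_ms("not-a-number") == 20_000
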
diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
index 72b0f1c5de..bf976b9e7c 100644
--- a/synapse/media/media_repository.py
+++ b/synapse/media/media_repository.py
@@ -19,6 +19,7 @@ import shutil
 from io import BytesIO
 from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
 
+import attr
 from matrix_common.types.mxc_uri import MXCUri
 
 import twisted.internet.error
@@ -26,13 +27,16 @@ import twisted.web.http
 from twisted.internet.defer import Deferred
 
 from synapse.api.errors import (
+    Codes,
     FederationDeniedError,
     HttpResponseException,
     NotFoundError,
     RequestSendFailed,
     SynapseError,
+    cs_error,
 )
 from synapse.config.repository import ThumbnailRequirement
+from synapse.http.server import respond_with_json
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import defer_to_thread
 from synapse.logging.opentracing import trace
@@ -50,6 +54,7 @@ from synapse.media.storage_provider import StorageProviderWrapper
 from synapse.media.thumbnailer import Thumbnailer, ThumbnailError
 from synapse.media.url_previewer import UrlPreviewer
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia
 from synapse.types import UserID
 from synapse.util.async_helpers import Linearizer
 from synapse.util.retryutils import NotRetryingDestination
@@ -78,6 +83,8 @@ class MediaRepository:
         self.store = hs.get_datastores().main
         self.max_upload_size = hs.config.media.max_upload_size
         self.max_image_pixels = hs.config.media.max_image_pixels
+        self.unused_expiration_time = hs.config.media.unused_expiration_time
+        self.max_pending_media_uploads = hs.config.media.max_pending_media_uploads
 
         Thumbnailer.set_limits(self.max_image_pixels)
 
@@ -184,6 +191,117 @@ class MediaRepository:
             self.recently_accessed_locals.add(media_id)
 
     @trace
+    async def create_media_id(self, auth_user: UserID) -> Tuple[str, int]:
+        """Create and store a media ID for a local user and return the MXC URI and its
+        expiration.
+
+        Args:
+            auth_user: The user_id of the uploader
+
+        Returns:
+            A tuple containing the MXC URI of the stored content and the timestamp at
+            which the MXC URI expires.
+        """
+        media_id = random_string(24)
+        now = self.clock.time_msec()
+        await self.store.store_local_media_id(
+            media_id=media_id,
+            time_now_ms=now,
+            user_id=auth_user,
+        )
+        return f"mxc://{self.server_name}/{media_id}", now + self.unused_expiration_time
+
+    @trace
+    async def reached_pending_media_limit(self, auth_user: UserID) -> Tuple[bool, int]:
+        """Check if the user is over the limit for pending media uploads.
+
+        Args:
+            auth_user: The user_id of the uploader
+
+        Returns:
+            A tuple with a boolean and an integer indicating whether the user has too
+            many pending media uploads and the timestamp at which the first pending
+            media will expire, respectively.
+        """
+        pending, first_expiration_ts = await self.store.count_pending_media(
+            user_id=auth_user
+        )
+        return pending >= self.max_pending_media_uploads, first_expiration_ts
+
+    @trace
+    async def verify_can_upload(self, media_id: str, auth_user: UserID) -> None:
+        """Verify that the media ID can be uploaded to by the given user. This
+        function checks that:
+
+        * the media ID exists
+        * the media ID does not already have content
+        * the user uploading is the same as the one who created the media ID
+        * the media ID has not expired
+
+        Args:
+            media_id: The media ID to verify
+            auth_user: The user_id of the uploader
+        """
+        media = await self.store.get_local_media(media_id)
+        if media is None:
+            raise SynapseError(404, "Unknow media ID", errcode=Codes.NOT_FOUND)
+
+        if media.user_id != auth_user.to_string():
+            raise SynapseError(
+                403,
+                "Only the creator of the media ID can upload to it",
+                errcode=Codes.FORBIDDEN,
+            )
+
+        if media.media_length is not None:
+            raise SynapseError(
+                409,
+                "Media ID already has content",
+                errcode=Codes.CANNOT_OVERWRITE_MEDIA,
+            )
+
+        expired_time_ms = self.clock.time_msec() - self.unused_expiration_time
+        if media.created_ts < expired_time_ms:
+            raise NotFoundError("Media ID has expired")
+
+    @trace
+    async def update_content(
+        self,
+        media_id: str,
+        media_type: str,
+        upload_name: Optional[str],
+        content: IO,
+        content_length: int,
+        auth_user: UserID,
+    ) -> None:
+        """Update the content of the given media ID.
+
+        Args:
+            media_id: The media ID to replace.
+            media_type: The content type of the file.
+            upload_name: The name of the file, if provided.
+            content: A file like object that is the content to store
+            content_length: The length of the content
+            auth_user: The user_id of the uploader
+        """
+        file_info = FileInfo(server_name=None, file_id=media_id)
+        fname = await self.media_storage.store_file(content, file_info)
+        logger.info("Stored local media in file %r", fname)
+
+        await self.store.update_local_media(
+            media_id=media_id,
+            media_type=media_type,
+            upload_name=upload_name,
+            media_length=content_length,
+            user_id=auth_user,
+        )
+
+        try:
+            await self._generate_thumbnails(None, media_id, media_id, media_type)
+        except Exception as e:
+            logger.info("Failed to generate thumbnails: %s", e)
+
+    @trace
     async def create_content(
         self,
         media_type: str,
@@ -229,8 +347,74 @@ class MediaRepository:
 
         return MXCUri(self.server_name, media_id)
 
+    def respond_not_yet_uploaded(self, request: SynapseRequest) -> None:
+        respond_with_json(
+            request,
+            504,
+            cs_error("Media has not been uploaded yet", code=Codes.NOT_YET_UPLOADED),
+            send_cors=True,
+        )
+
+    async def get_local_media_info(
+        self, request: SynapseRequest, media_id: str, max_timeout_ms: int
+    ) -> Optional[LocalMedia]:
+        """Gets the info dictionary for given local media ID. If the media has
+        not been uploaded yet, this function will wait up to ``max_timeout_ms``
+        milliseconds for the media to be uploaded.
+
+        Args:
+            request: The incoming request.
+            media_id: The media ID of the content. (This is the same as
+                the file_id for local content.)
+            max_timeout_ms: the maximum number of milliseconds to wait for the
+                media to be uploaded.
+
+        Returns:
+            Either the info dictionary for the given local media ID or
+            ``None``. If ``None``, then no further processing is necessary as
+            this function will send the necessary JSON response.
+        """
+        wait_until = self.clock.time_msec() + max_timeout_ms
+        while True:
+            # Get the info for the media
+            media_info = await self.store.get_local_media(media_id)
+            if not media_info:
+                logger.info("Media %s is unknown", media_id)
+                respond_404(request)
+                return None
+
+            if media_info.quarantined_by:
+                logger.info("Media %s is quarantined", media_id)
+                respond_404(request)
+                return None
+
+            # The file has been uploaded, so stop looping
+            if media_info.media_length is not None:
+                return media_info
+
+            # Check if the media ID has expired and still hasn't been uploaded to.
+            now = self.clock.time_msec()
+            expired_time_ms = now - self.unused_expiration_time
+            if media_info.created_ts < expired_time_ms:
+                logger.info("Media %s has expired without being uploaded", media_id)
+                respond_404(request)
+                return None
+
+            if now >= wait_until:
+                break
+
+            await self.clock.sleep(0.5)
+
+        logger.info("Media %s has not yet been uploaded", media_id)
+        self.respond_not_yet_uploaded(request)
+        return None
+
     async def get_local_media(
-        self, request: SynapseRequest, media_id: str, name: Optional[str]
+        self,
+        request: SynapseRequest,
+        media_id: str,
+        name: Optional[str],
+        max_timeout_ms: int,
     ) -> None:
         """Responds to requests for local media, if exists, or returns 404.
 
@@ -240,23 +424,24 @@ class MediaRepository:
                 the file_id for local content.)
             name: Optional name that, if specified, will be used as
                 the filename in the Content-Disposition header of the response.
+            max_timeout_ms: the maximum number of milliseconds to wait for the
+                media to be uploaded.
 
         Returns:
             Resolves once a response has successfully been written to request
         """
-        media_info = await self.store.get_local_media(media_id)
-        if not media_info or media_info["quarantined_by"]:
-            respond_404(request)
+        media_info = await self.get_local_media_info(request, media_id, max_timeout_ms)
+        if not media_info:
             return
 
         self.mark_recently_accessed(None, media_id)
 
-        media_type = media_info["media_type"]
+        media_type = media_info.media_type
         if not media_type:
             media_type = "application/octet-stream"
-        media_length = media_info["media_length"]
-        upload_name = name if name else media_info["upload_name"]
-        url_cache = media_info["url_cache"]
+        media_length = media_info.media_length
+        upload_name = name if name else media_info.upload_name
+        url_cache = media_info.url_cache
 
         file_info = FileInfo(None, media_id, url_cache=bool(url_cache))
 
@@ -271,6 +456,7 @@ class MediaRepository:
         server_name: str,
         media_id: str,
         name: Optional[str],
+        max_timeout_ms: int,
     ) -> None:
         """Respond to requests for remote media.
 
@@ -280,6 +466,8 @@ class MediaRepository:
             media_id: The media ID of the content (as defined by the remote server).
             name: Optional name that, if specified, will be used as
                 the filename in the Content-Disposition header of the response.
+            max_timeout_ms: the maximum number of milliseconds to wait for the
+                media to be uploaded.
 
         Returns:
             Resolves once a response has successfully been written to request
@@ -305,27 +493,33 @@ class MediaRepository:
         key = (server_name, media_id)
         async with self.remote_media_linearizer.queue(key):
             responder, media_info = await self._get_remote_media_impl(
-                server_name, media_id
+                server_name, media_id, max_timeout_ms
             )
 
         # We deliberately stream the file outside the lock
-        if responder:
-            media_type = media_info["media_type"]
-            media_length = media_info["media_length"]
-            upload_name = name if name else media_info["upload_name"]
+        if responder and media_info:
+            upload_name = name if name else media_info.upload_name
             await respond_with_responder(
-                request, responder, media_type, media_length, upload_name
+                request,
+                responder,
+                media_info.media_type,
+                media_info.media_length,
+                upload_name,
             )
         else:
             respond_404(request)
 
-    async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
+    async def get_remote_media_info(
+        self, server_name: str, media_id: str, max_timeout_ms: int
+    ) -> RemoteMedia:
         """Gets the media info associated with the remote file, downloading
         if necessary.
 
         Args:
             server_name: Remote server_name where the media originated.
             media_id: The media ID of the content (as defined by the remote server).
+            max_timeout_ms: the maximum number of milliseconds to wait for the
+                media to be uploaded.
 
         Returns:
             The media info of the file
@@ -341,7 +535,7 @@ class MediaRepository:
         key = (server_name, media_id)
         async with self.remote_media_linearizer.queue(key):
             responder, media_info = await self._get_remote_media_impl(
-                server_name, media_id
+                server_name, media_id, max_timeout_ms
             )
 
         # Ensure we actually use the responder so that it releases resources
@@ -352,8 +546,8 @@ class MediaRepository:
         return media_info
 
     async def _get_remote_media_impl(
-        self, server_name: str, media_id: str
-    ) -> Tuple[Optional[Responder], dict]:
+        self, server_name: str, media_id: str, max_timeout_ms: int
+    ) -> Tuple[Optional[Responder], RemoteMedia]:
         """Looks for media in local cache, if not there then attempt to
         download from remote server.
 
@@ -361,6 +555,8 @@ class MediaRepository:
             server_name: Remote server_name where the media originated.
             media_id: The media ID of the content (as defined by the
                 remote server).
+            max_timeout_ms: the maximum number of milliseconds to wait for the
+                media to be uploaded.
 
         Returns:
             A tuple of responder and the media info of the file.
@@ -373,15 +569,17 @@ class MediaRepository:
 
         # If we have an entry in the DB, try and look for it
         if media_info:
-            file_id = media_info["filesystem_id"]
+            file_id = media_info.filesystem_id
             file_info = FileInfo(server_name, file_id)
 
-            if media_info["quarantined_by"]:
+            if media_info.quarantined_by:
                 logger.info("Media is quarantined")
                 raise NotFoundError()
 
-            if not media_info["media_type"]:
-                media_info["media_type"] = "application/octet-stream"
+            if not media_info.media_type:
+                media_info = attr.evolve(
+                    media_info, media_type="application/octet-stream"
+                )
 
             responder = await self.media_storage.fetch_media(file_info)
             if responder:
@@ -391,8 +589,7 @@ class MediaRepository:
 
         try:
             media_info = await self._download_remote_file(
-                server_name,
-                media_id,
+                server_name, media_id, max_timeout_ms
             )
         except SynapseError:
             raise
@@ -403,9 +600,9 @@ class MediaRepository:
             if not media_info:
                 raise e
 
-        file_id = media_info["filesystem_id"]
-        if not media_info["media_type"]:
-            media_info["media_type"] = "application/octet-stream"
+        file_id = media_info.filesystem_id
+        if not media_info.media_type:
+            media_info = attr.evolve(media_info, media_type="application/octet-stream")
         file_info = FileInfo(server_name, file_id)
 
         # We generate thumbnails even if another process downloaded the media
@@ -415,7 +612,7 @@ class MediaRepository:
         # otherwise they'll request thumbnails and get a 404 if they're not
         # ready yet.
         await self._generate_thumbnails(
-            server_name, media_id, file_id, media_info["media_type"]
+            server_name, media_id, file_id, media_info.media_type
         )
 
         responder = await self.media_storage.fetch_media(file_info)
@@ -425,7 +622,8 @@ class MediaRepository:
         self,
         server_name: str,
         media_id: str,
-    ) -> dict:
+        max_timeout_ms: int,
+    ) -> RemoteMedia:
         """Attempt to download the remote file from the given server name,
         using the given file_id as the local id.
 
@@ -434,7 +632,8 @@ class MediaRepository:
             media_id: The media ID of the content (as defined by the
                 remote server). This is different than the file_id, which is
                 locally generated.
-            file_id: Local file ID
+            max_timeout_ms: the maximum number of milliseconds to wait for the
+                media to be uploaded.
 
         Returns:
             The media info of the file.
@@ -458,7 +657,8 @@ class MediaRepository:
                         # tell the remote server to 404 if it doesn't
                         # recognise the server_name, to make sure we don't
                         # end up with a routing loop.
-                        "allow_remote": "false"
+                        "allow_remote": "false",
+                        "timeout_ms": str(max_timeout_ms),
                     },
                 )
             except RequestSendFailed as e:
@@ -518,7 +718,7 @@ class MediaRepository:
                 origin=server_name,
                 media_id=media_id,
                 media_type=media_type,
-                time_now_ms=self.clock.time_msec(),
+                time_now_ms=time_now_ms,
                 upload_name=upload_name,
                 media_length=length,
                 filesystem_id=file_id,
@@ -526,15 +726,17 @@ class MediaRepository:
 
         logger.info("Stored remote media in file %r", fname)
 
-        media_info = {
-            "media_type": media_type,
-            "media_length": length,
-            "upload_name": upload_name,
-            "created_ts": time_now_ms,
-            "filesystem_id": file_id,
-        }
-
-        return media_info
+        return RemoteMedia(
+            media_origin=server_name,
+            media_id=media_id,
+            media_type=media_type,
+            media_length=length,
+            upload_name=upload_name,
+            created_ts=time_now_ms,
+            filesystem_id=file_id,
+            last_access_ts=time_now_ms,
+            quarantined_by=None,
+        )
 
     def _get_thumbnail_requirements(
         self, media_type: str
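
Taken together, reached_pending_media_limit, create_media_id, verify_can_upload and update_content support a two-phase upload: reserve a media ID first, attach the bytes later (the new rest/media/create_resource.py and the updated upload_resource.py in the diffstat do the real wiring). A rough, duck-typed sketch of that flow; the helper names and response keys follow MSC2246 conventions and are assumptions, not taken from this diff:

    from typing import IO, Any, Dict, Optional

    async def handle_create(repo: Any, requester_user: Any) -> Dict[str, Any]:
        """Phase 1: reserve a media ID, unless too many uploads are still pending."""
        limited, first_expiry_ts = await repo.reached_pending_media_limit(requester_user)
        if limited:
            # The real servlet would raise a SynapseError / apply a ratelimit here.
            raise RuntimeError(f"too many pending uploads until {first_expiry_ts}")
        content_uri, unused_expires_at = await repo.create_media_id(requester_user)
        return {"content_uri": content_uri, "unused_expires_at": unused_expires_at}

    async def handle_async_upload(
        repo: Any,
        requester_user: Any,
        media_id: str,
        media_type: str,
        upload_name: Optional[str],
        content: IO,
        content_length: int,
    ) -> None:
        """Phase 2: attach the bytes to a previously reserved media ID."""
        # Rejects unknown or expired IDs, IDs owned by another user, and IDs that
        # already have content (M_CANNOT_OVERWRITE_MEDIA).
        await repo.verify_can_upload(media_id, requester_user)
        await repo.update_content(
            media_id=media_id,
            media_type=media_type,
            upload_name=upload_name,
            content=content,
            content_length=content_length,
            auth_user=requester_user,
        )
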
diff --git a/synapse/media/url_previewer.py b/synapse/media/url_previewer.py
index 9b5a3dd5f4..44aac21de6 100644
--- a/synapse/media/url_previewer.py
+++ b/synapse/media/url_previewer.py
@@ -240,15 +240,14 @@ class UrlPreviewer:
         cache_result = await self.store.get_url_cache(url, ts)
         if (
             cache_result
-            and cache_result["expires_ts"] > ts
-            and cache_result["response_code"] / 100 == 2
+            and cache_result.expires_ts > ts
+            and cache_result.response_code // 100 == 2
         ):
             # It may be stored as text in the database, not as bytes (such as
             # PostgreSQL). If so, encode it back before handing it on.
-            og = cache_result["og"]
-            if isinstance(og, str):
-                og = og.encode("utf8")
-            return og
+            if isinstance(cache_result.og, str):
+                return cache_result.og.encode("utf8")
+            return cache_result.og
 
         # If this URL can be accessed via an allowed oEmbed, use that instead.
         url_to_download = url
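
The switch from / to // is significant under Python 3's true division: only an exact 200 compared equal to 2 before, so other 2xx cache rows were effectively treated as misses. For example:

    assert 200 / 100 == 2 and 204 / 100 != 2    # old check: 204, 201, 206, ... rejected
    assert 204 // 100 == 2 and 299 // 100 == 2  # new check: any 2xx status accepted
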
diff --git a/synapse/metrics/_reactor_metrics.py b/synapse/metrics/_reactor_metrics.py
index a2c6e6842d..dd486dd3e2 100644
--- a/synapse/metrics/_reactor_metrics.py
+++ b/synapse/metrics/_reactor_metrics.py
@@ -12,17 +12,45 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import select
+import logging
 import time
-from typing import Any, Iterable, List, Tuple
+from selectors import SelectSelector, _PollLikeSelector  # type: ignore[attr-defined]
+from typing import Any, Callable, Iterable
 
 from prometheus_client import Histogram, Metric
 from prometheus_client.core import REGISTRY, GaugeMetricFamily
 
-from twisted.internet import reactor
+from twisted.internet import reactor, selectreactor
+from twisted.internet.asyncioreactor import AsyncioSelectorReactor
 
 from synapse.metrics._types import Collector
 
+try:
+    from selectors import KqueueSelector
+except ImportError:
+
+    class KqueueSelector:  # type: ignore[no-redef]
+        pass
+
+
+try:
+    from twisted.internet.epollreactor import EPollReactor
+except ImportError:
+
+    class EPollReactor:  # type: ignore[no-redef]
+        pass
+
+
+try:
+    from twisted.internet.pollreactor import PollReactor
+except ImportError:
+
+    class PollReactor:  # type: ignore[no-redef]
+        pass
+
+
+logger = logging.getLogger(__name__)
+
 #
 # Twisted reactor metrics
 #
@@ -34,52 +62,100 @@ tick_time = Histogram(
 )
 
 
-class EpollWrapper:
-    """a wrapper for an epoll object which records the time between polls"""
+class CallWrapper:
+    """A wrapper for a callable which records the time between calls"""
 
-    def __init__(self, poller: "select.epoll"):  # type: ignore[name-defined]
+    def __init__(self, wrapped: Callable[..., Any]):
         self.last_polled = time.time()
-        self._poller = poller
+        self._wrapped = wrapped
 
-    def poll(self, *args, **kwargs) -> List[Tuple[int, int]]:  # type: ignore[no-untyped-def]
-        # record the time since poll() was last called. This gives a good proxy for
+    def __call__(self, *args, **kwargs) -> Any:  # type: ignore[no-untyped-def]
+        # record the time since this was last called. This gives a good proxy for
         # how long it takes to run everything in the reactor - ie, how long anything
         # waiting for the next tick will have to wait.
         tick_time.observe(time.time() - self.last_polled)
 
-        ret = self._poller.poll(*args, **kwargs)
+        ret = self._wrapped(*args, **kwargs)
 
         self.last_polled = time.time()
         return ret
 
+
+class ObjWrapper:
+    """A wrapper for an object which wraps a specified method in CallWrapper.
+
+    Other methods/attributes are passed to the original object.
+
+    This is necessary when the wrapped object does not allow the attribute to be
+    overwritten.
+    """
+
+    def __init__(self, wrapped: Any, method_name: str):
+        self._wrapped = wrapped
+        self._method_name = method_name
+        self._wrapped_method = CallWrapper(getattr(wrapped, method_name))
+
     def __getattr__(self, item: str) -> Any:
-        return getattr(self._poller, item)
+        if item == self._method_name:
+            return self._wrapped_method
+
+        return getattr(self._wrapped, item)
 
 
 class ReactorLastSeenMetric(Collector):
-    def __init__(self, epoll_wrapper: EpollWrapper):
-        self._epoll_wrapper = epoll_wrapper
+    def __init__(self, call_wrapper: CallWrapper):
+        self._call_wrapper = call_wrapper
 
     def collect(self) -> Iterable[Metric]:
         cm = GaugeMetricFamily(
             "python_twisted_reactor_last_seen",
             "Seconds since the Twisted reactor was last seen",
         )
-        cm.add_metric([], time.time() - self._epoll_wrapper.last_polled)
+        cm.add_metric([], time.time() - self._call_wrapper.last_polled)
         yield cm
 
 
+# Twisted has already selected a reasonable reactor for us, so we can make
+# assumptions about its shape.
+wrapper = None
 try:
-    # if the reactor has a `_poller` attribute, which is an `epoll` object
-    # (ie, it's an EPollReactor), we wrap the `epoll` with a thing that will
-    # measure the time between ticks
-    from select import epoll  # type: ignore[attr-defined]
-
-    poller = reactor._poller  # type: ignore[attr-defined]
-except (AttributeError, ImportError):
-    pass
-else:
-    if isinstance(poller, epoll):
-        poller = EpollWrapper(poller)
-        reactor._poller = poller  # type: ignore[attr-defined]
-        REGISTRY.register(ReactorLastSeenMetric(poller))
+    if isinstance(reactor, (PollReactor, EPollReactor)):
+        reactor._poller = ObjWrapper(reactor._poller, "poll")  # type: ignore[attr-defined]
+        wrapper = reactor._poller._wrapped_method  # type: ignore[attr-defined]
+
+    elif isinstance(reactor, selectreactor.SelectReactor):
+        # Twisted uses a module-level _select function.
+        wrapper = selectreactor._select = CallWrapper(selectreactor._select)
+
+    elif isinstance(reactor, AsyncioSelectorReactor):
+        # For asyncio, look at the underlying asyncio event loop.
+        asyncio_loop = reactor._asyncioEventloop  # A sub-class of BaseEventLoop.
+
+        # A sub-class of BaseSelector.
+        selector = asyncio_loop._selector  # type: ignore[attr-defined]
+
+        if isinstance(selector, SelectSelector):
+            wrapper = selector._select = CallWrapper(selector._select)  # type: ignore[attr-defined]
+
+        # poll, epoll, and /dev/poll.
+        elif isinstance(selector, _PollLikeSelector):
+            selector._selector = ObjWrapper(selector._selector, "poll")  # type: ignore[attr-defined]
+            wrapper = selector._selector._wrapped_method  # type: ignore[attr-defined]
+
+        elif isinstance(selector, KqueueSelector):
+            selector._selector = ObjWrapper(selector._selector, "control")  # type: ignore[attr-defined]
+            wrapper = selector._selector._wrapped_method  # type: ignore[attr-defined]
+
+        else:
+            # E.g. this does not support the (Windows-only) ProactorEventLoop.
+            logger.warning(
+                "Skipping configuring ReactorLastSeenMetric: unexpected asyncio loop selector: %r via %r",
+                selector,
+                asyncio_loop,
+            )
+except Exception as e:
+    logger.warning("Configuring ReactorLastSeenMetric failed: %r", e)
+
+
+if wrapper:
+    REGISTRY.register(ReactorLastSeenMetric(wrapper))
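
The wrappers above only measure the gap between successive calls into the reactor's polling primitive and export it via the tick_time histogram. A self-contained sketch of that timing idea, independent of Twisted; the histogram name and the wrapped callable here are illustrative.

import time
from typing import Any, Callable

from prometheus_client import Histogram

demo_tick_time = Histogram("demo_reactor_tick_time", "Seconds between polls")


class TimedCall:
    """Wraps a callable and observes the time elapsed since the previous call."""

    def __init__(self, wrapped: Callable[..., Any]):
        self.last_polled = time.time()
        self._wrapped = wrapped

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Time since the wrapped poll was last entered: a proxy for how long a
        # full pass around the event loop took.
        demo_tick_time.observe(time.time() - self.last_polled)
        ret = self._wrapped(*args, **kwargs)
        self.last_polled = time.time()
        return ret


# Wrap whatever blocking "wait for events" primitive the loop uses.
slow_poll = TimedCall(time.sleep)
slow_poll(0.01)
slow_poll(0.01)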
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 755c59274c..812144a128 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -1860,7 +1860,8 @@ class PublicRoomListManager:
         if not room:
             return False
 
-        return room.get("is_public", False)
+        # The first item is whether the room is public.
+        return room[0]
 
     async def add_room_to_public_room_list(self, room_id: str) -> None:
         """Publishes a room to the public room list.
diff --git a/synapse/module_api/callbacks/third_party_event_rules_callbacks.py b/synapse/module_api/callbacks/third_party_event_rules_callbacks.py
index ecaeef3511..7419785aff 100644
--- a/synapse/module_api/callbacks/third_party_event_rules_callbacks.py
+++ b/synapse/module_api/callbacks/third_party_event_rules_callbacks.py
@@ -295,7 +295,8 @@ class ThirdPartyEventRulesModuleApiCallbacks:
                 raise
             except SynapseError as e:
                 # FIXME: Being able to throw SynapseErrors is relied upon by
-                # some modules. PR #10386 accidentally broke this ability.
+                # some modules. PR https://github.com/matrix-org/synapse/pull/10386
+                # accidentally broke this ability.
                 # That said, we aren't keen on exposing this implementation detail
                 # to modules and we should one day have a proper way to do what
                 # is wanted.
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 14784312dc..5934b1ef34 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -25,10 +25,13 @@ from typing import (
     Sequence,
     Tuple,
     Union,
+    cast,
 )
 
 from prometheus_client import Counter
 
+from twisted.internet.defer import Deferred
+
 from synapse.api.constants import (
     MAIN_TIMELINE,
     EventContentFields,
@@ -40,11 +43,15 @@ from synapse.api.room_versions import PushRuleRoomFlag
 from synapse.event_auth import auth_types_for_event, get_user_power_level
 from synapse.events import EventBase, relation_from_event
 from synapse.events.snapshot import EventContext
+from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.state import POWER_KEY
 from synapse.storage.databases.main.roommember import EventIdMembership
+from synapse.storage.roommember import ProfileInfo
 from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator
 from synapse.types import JsonValue
 from synapse.types.state import StateFilter
+from synapse.util import unwrapFirstError
+from synapse.util.async_helpers import gather_results
 from synapse.util.caches import register_cache
 from synapse.util.metrics import measure_func
 from synapse.visibility import filter_event_for_clients_with_state
@@ -342,15 +349,41 @@ class BulkPushRuleEvaluator:
         rules_by_user = await self._get_rules_for_event(event)
         actions_by_user: Dict[str, Collection[Union[Mapping, str]]] = {}
 
-        room_member_count = await self.store.get_number_joined_users_in_room(
-            event.room_id
-        )
-
+        # Gather a bunch of info in parallel.
+        #
+        # This has a lot of ignored types and casting due to the use of @cached
+        # decorated functions passed into run_in_background.
+        #
+        # See https://github.com/matrix-org/synapse/issues/16606
         (
-            power_levels,
-            sender_power_level,
-        ) = await self._get_power_levels_and_sender_level(
-            event, context, event_id_to_event
+            room_member_count,
+            (power_levels, sender_power_level),
+            related_events,
+            profiles,
+        ) = await make_deferred_yieldable(
+            cast(
+                "Deferred[Tuple[int, Tuple[dict, Optional[int]], Dict[str, Dict[str, JsonValue]], Mapping[str, ProfileInfo]]]",
+                gather_results(
+                    (
+                        run_in_background(  # type: ignore[call-arg]
+                            self.store.get_number_joined_users_in_room, event.room_id  # type: ignore[arg-type]
+                        ),
+                        run_in_background(
+                            self._get_power_levels_and_sender_level,
+                            event,
+                            context,
+                            event_id_to_event,
+                        ),
+                        run_in_background(self._related_events, event),
+                        run_in_background(  # type: ignore[call-arg]
+                            self.store.get_subset_users_in_room_with_profiles,
+                            event.room_id,  # type: ignore[arg-type]
+                            rules_by_user.keys(),  # type: ignore[arg-type]
+                        ),
+                    ),
+                    consumeErrors=True,
+                ).addErrback(unwrapFirstError),
+            )
         )
 
         # Find the event's thread ID.
@@ -366,8 +399,6 @@ class BulkPushRuleEvaluator:
                 # the parent is part of a thread.
                 thread_id = await self.store.get_thread_id(relation.parent_id)
 
-        related_events = await self._related_events(event)
-
         # It's possible that old room versions have non-integer power levels (floats or
         # strings; even the occasional `null`). For old rooms, we interpret these as if
         # they were integers. Do this here for the `@room` power level threshold.
@@ -400,11 +431,6 @@ class BulkPushRuleEvaluator:
             self.hs.config.experimental.msc1767_enabled,  # MSC3931 flag
         )
 
-        users = rules_by_user.keys()
-        profiles = await self.store.get_subset_users_in_room_with_profiles(
-            event.room_id, users
-        )
-
         for uid, rules in rules_by_user.items():
             if event.sender == uid:
                 continue
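
The evaluator above now starts its four store lookups concurrently using Synapse's run_in_background/gather_results helpers. Below is a rough asyncio analogue of the same fan-out-and-wait pattern, with hypothetical fetch functions standing in for the real store calls.

import asyncio
from typing import Mapping, Tuple


async def fetch_member_count(room_id: str) -> int:
    return 42  # hypothetical stand-in for get_number_joined_users_in_room


async def fetch_power_levels(room_id: str) -> Tuple[dict, int]:
    return {"users_default": 0}, 100  # stand-in for the power level lookup


async def fetch_profiles(room_id: str) -> Mapping[str, str]:
    return {"@alice:example.org": "Alice"}  # stand-in for the profile lookup


async def gather_room_info(room_id: str) -> None:
    # Start all lookups concurrently and wait for every result, mirroring the
    # gather_results(...) call in the evaluator.
    member_count, (power_levels, sender_level), profiles = await asyncio.gather(
        fetch_member_count(room_id),
        fetch_power_levels(room_id),
        fetch_profiles(room_id),
    )
    print(member_count, power_levels, sender_level, sorted(profiles))


asyncio.run(gather_room_info("!room:example.org"))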
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index afd03137f0..c14a18ba2e 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -257,6 +257,11 @@ class ReplicationCommandHandler:
         if hs.config.redis.redis_enabled:
             self._notifier.add_lock_released_callback(self.on_lock_released)
 
+        # Marks whether we should send POSITION commands for all streams ASAP.
+        # This is checked by the `ReplicationStreamer`, which manages sending
+        # RDATA/POSITION commands.
+        self._should_announce_positions = True
+
     def subscribe_to_channel(self, channel_name: str) -> None:
         """
         Indicates that we wish to subscribe to a Redis channel by name.
@@ -397,29 +402,23 @@ class ReplicationCommandHandler:
         return self._streams_to_replicate
 
     def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand) -> None:
-        self.send_positions_to_connection(conn)
+        self.send_positions_to_connection()
 
-    def send_positions_to_connection(self, conn: IReplicationConnection) -> None:
+    def send_positions_to_connection(self) -> None:
         """Send current position of all streams this process is source of to
         the connection.
         """
 
-        # We respond with current position of all streams this instance
-        # replicates.
-        for stream in self.get_streams_to_replicate():
-            # Note that we use the current token as the prev token here (rather
-            # than stream.last_token), as we can't be sure that there have been
-            # no rows written between last token and the current token (since we
-            # might be racing with the replication sending bg process).
-            current_token = stream.current_token(self._instance_name)
-            self.send_command(
-                PositionCommand(
-                    stream.NAME,
-                    self._instance_name,
-                    current_token,
-                    current_token,
-                )
-            )
+        self._should_announce_positions = True
+        self._notifier.notify_replication()
+
+    def should_announce_positions(self) -> bool:
+        """Check if we should send POSITION commands for all streams ASAP."""
+        return self._should_announce_positions
+
+    def will_announce_positions(self) -> None:
+        """Mark that we're about to send POSITIONs out for all streams."""
+        self._should_announce_positions = False
 
     def on_USER_SYNC(
         self, conn: IReplicationConnection, cmd: UserSyncCommand
@@ -588,6 +587,21 @@ class ReplicationCommandHandler:
 
         logger.debug("Handling '%s %s'", cmd.NAME, cmd.to_line())
 
+        # Check if we can early discard this position. We can only do so for
+        # connected streams.
+        stream = self._streams[cmd.stream_name]
+        if stream.can_discard_position(
+            cmd.instance_name, cmd.prev_token, cmd.new_token
+        ) and self.is_stream_connected(conn, cmd.stream_name):
+            logger.debug(
+                "Discarding redundant POSITION %s/%s %s %s",
+                cmd.instance_name,
+                cmd.stream_name,
+                cmd.prev_token,
+                cmd.new_token,
+            )
+            return
+
         self._add_command_to_stream_queue(conn, cmd)
 
     async def _process_position(
@@ -599,6 +613,18 @@ class ReplicationCommandHandler:
         """
         stream = self._streams[stream_name]
 
+        if stream.can_discard_position(
+            cmd.instance_name, cmd.prev_token, cmd.new_token
+        ) and self.is_stream_connected(conn, cmd.stream_name):
+            logger.debug(
+                "Discarding redundant POSITION %s/%s %s %s",
+                cmd.instance_name,
+                cmd.stream_name,
+                cmd.prev_token,
+                cmd.new_token,
+            )
+            return
+
         # We're about to go and catch up with the stream, so remove from set
         # of connected streams.
         for streams in self._streams_by_connection.values():
@@ -626,8 +652,9 @@ class ReplicationCommandHandler:
             # for why this can happen.
 
             logger.info(
-                "Fetching replication rows for '%s' between %i and %i",
+                "Fetching replication rows for '%s' / %s between %i and %i",
                 stream_name,
+                cmd.instance_name,
                 current_token,
                 cmd.new_token,
             )
@@ -657,6 +684,13 @@ class ReplicationCommandHandler:
 
         self._streams_by_connection.setdefault(conn, set()).add(stream_name)
 
+    def is_stream_connected(
+        self, conn: IReplicationConnection, stream_name: str
+    ) -> bool:
+        """Return if stream has been successfully connected and is ready to
+        receive updates"""
+        return stream_name in self._streams_by_connection.get(conn, ())
+
     def on_REMOTE_SERVER_UP(
         self, conn: IReplicationConnection, cmd: RemoteServerUpCommand
     ) -> None:
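
The handler above now drops POSITION commands that cannot advance an already connected stream, and merely flags that positions should be re-announced rather than sending them inline. A small standalone sketch of the discard check; the data structures are simplified stand-ins.

from typing import Dict, Set


def should_discard_position(
    connected_streams: Set[str],
    current_tokens: Dict[str, int],
    stream_name: str,
    new_token: int,
) -> bool:
    """A POSITION is redundant if the stream is already connected and our local
    view of it has advanced at least as far as the new token."""
    return (
        stream_name in connected_streams
        and new_token <= current_tokens.get(stream_name, 0)
    )


# We already track "events" up to token 120, so a POSITION advancing to 110
# carries no new information; a disconnected stream must always be processed.
assert should_discard_position({"events"}, {"events": 120}, "events", 110)
assert not should_discard_position(set(), {"events": 120}, "events", 110)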
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 7e96145b3b..1fa37bb888 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -141,7 +141,7 @@ class RedisSubscriber(SubscriberProtocol):
         # We send out our positions when there is a new connection in case the
         # other side missed updates. We do this for Redis connections as the
         # other side won't know we've connected and so won't issue a REPLICATE.
-        self.synapse_handler.send_positions_to_connection(self)
+        self.synapse_handler.send_positions_to_connection()
 
     def messageReceived(self, pattern: str, channel: str, message: str) -> None:
         """Received a message from redis."""
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 38abb5df54..d15828f2d3 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -123,7 +123,7 @@ class ReplicationStreamer:
 
         # We check up front to see if anything has actually changed, as we get
         # poked because of changes that happened on other instances.
-        if all(
+        if not self.command_handler.should_announce_positions() and all(
             stream.last_token == stream.current_token(self._instance_name)
             for stream in self.streams
         ):
@@ -158,6 +158,21 @@ class ReplicationStreamer:
                         all_streams = list(all_streams)
                         random.shuffle(all_streams)
 
+                    if self.command_handler.should_announce_positions():
+                        # We need to send out POSITIONs for all streams, usually
+                        # because a worker has reconnected.
+                        self.command_handler.will_announce_positions()
+
+                        for stream in all_streams:
+                            self.command_handler.send_command(
+                                PositionCommand(
+                                    stream.NAME,
+                                    self._instance_name,
+                                    stream.last_token,
+                                    stream.last_token,
+                                )
+                            )
+
                     for stream in all_streams:
                         if stream.last_token == stream.current_token(
                             self._instance_name
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 58a44029aa..cc34dfb322 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -144,6 +144,16 @@ class Stream:
         """
         raise NotImplementedError()
 
+    def can_discard_position(
+        self, instance_name: str, prev_token: int, new_token: int
+    ) -> bool:
+        """Whether or not a position command for this stream can be discarded.
+
+        Useful for streams that can never go backwards and where we already know
+        the stream ID for the instance has advanced.
+        """
+        return False
+
     def discard_updates_and_advance(self) -> None:
         """Called when the stream should advance but the updates would be discarded,
         e.g. when there are no currently connected workers.
@@ -221,6 +231,14 @@ class _StreamFromIdGen(Stream):
     def minimal_local_current_token(self) -> Token:
         return self._stream_id_gen.get_minimal_local_current_token()
 
+    def can_discard_position(
+        self, instance_name: str, prev_token: int, new_token: int
+    ) -> bool:
+        # These streams can't go backwards, so we know we can ignore any
+        # positions where the tokens are from before the current token.
+
+        return new_token <= self.current_token(instance_name)
+
 
 def current_token_without_instance(
     current_token: Callable[[], int]
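
For streams backed by a monotonic stream ID generator, the rule above is simply: a POSITION whose new token is not ahead of the token we already track for that instance carries no new information. A toy illustration (not Synapse's Stream class):

class ToyStream:
    """A stream whose per-instance tokens only ever move forwards."""

    def __init__(self) -> None:
        self._tokens = {"worker1": 50}

    def current_token(self, instance_name: str) -> int:
        return self._tokens.get(instance_name, 0)

    def can_discard_position(
        self, instance_name: str, prev_token: int, new_token: int
    ) -> bool:
        # Same rule as _StreamFromIdGen above: nothing at or before our current
        # token can be new to us.
        return new_token <= self.current_token(instance_name)


stream = ToyStream()
assert stream.can_discard_position("worker1", 40, 45)  # already caught up past 45
assert not stream.can_discard_position("worker1", 50, 60)  # genuinely new rows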
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 9bd0d764f8..91edfd45d7 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -88,6 +88,7 @@ from synapse.rest.admin.users import (
     UserByThreePid,
     UserMembershipRestServlet,
     UserRegisterServlet,
+    UserReplaceMasterCrossSigningKeyRestServlet,
     UserRestServletV2,
     UsersRestServletV2,
     UserTokenRestServlet,
@@ -292,6 +293,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     ListDestinationsRestServlet(hs).register(http_server)
     RoomMessagesRestServlet(hs).register(http_server)
     RoomTimestampToEventRestServlet(hs).register(http_server)
+    UserReplaceMasterCrossSigningKeyRestServlet(hs).register(http_server)
     UserByExternalId(hs).register(http_server)
     UserByThreePid(hs).register(http_server)
 
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index b7637dff0b..8cf5268854 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -17,6 +17,8 @@ import logging
 from http import HTTPStatus
 from typing import TYPE_CHECKING, Optional, Tuple
 
+import attr
+
 from synapse.api.constants import Direction
 from synapse.api.errors import Codes, NotFoundError, SynapseError
 from synapse.http.server import HttpServer
@@ -418,7 +420,7 @@ class UserMediaRestServlet(RestServlet):
             start, limit, user_id, order_by, direction
         )
 
-        ret = {"media": media, "total": total}
+        ret = {"media": [attr.asdict(m) for m in media], "total": total}
         if (start + limit) < total:
             ret["next_token"] = start + len(media)
 
@@ -477,7 +479,7 @@ class UserMediaRestServlet(RestServlet):
         )
 
         deleted_media, total = await self.media_repository.delete_local_media_ids(
-            [row["media_id"] for row in media]
+            [m.media_id for m in media]
         )
 
         return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total}
diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py
index ffce92d45e..f3e06d3da3 100644
--- a/synapse/rest/admin/registration_tokens.py
+++ b/synapse/rest/admin/registration_tokens.py
@@ -77,7 +77,18 @@ class ListRegistrationTokensRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
         valid = parse_boolean(request, "valid")
         token_list = await self.store.get_registration_tokens(valid)
-        return HTTPStatus.OK, {"registration_tokens": token_list}
+        return HTTPStatus.OK, {
+            "registration_tokens": [
+                {
+                    "token": t[0],
+                    "uses_allowed": t[1],
+                    "pending": t[2],
+                    "completed": t[3],
+                    "expiry_time": t[4],
+                }
+                for t in token_list
+            ]
+        }
 
 
 class NewRegistrationTokenRestServlet(RestServlet):
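
get_registration_tokens now returns plain tuples, which the servlet above converts into the documented JSON objects. A short sketch of the same mapping under that assumed tuple layout:

from typing import List, Optional, Tuple

# (token, uses_allowed, pending, completed, expiry_time) rows, matching the
# indices used in the servlet above.
TokenRow = Tuple[str, Optional[int], int, int, Optional[int]]


def rows_to_json(token_list: List[TokenRow]) -> dict:
    return {
        "registration_tokens": [
            {
                "token": t[0],
                "uses_allowed": t[1],
                "pending": t[2],
                "completed": t[3],
                "expiry_time": t[4],
            }
            for t in token_list
        ]
    }


print(rows_to_json([("abcd", 10, 0, 3, None)]))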
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 0659f22a89..7e40bea8aa 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -16,6 +16,8 @@ from http import HTTPStatus
 from typing import TYPE_CHECKING, List, Optional, Tuple, cast
 from urllib import parse as urlparse
 
+import attr
+
 from synapse.api.constants import Direction, EventTypes, JoinRules, Membership
 from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
 from synapse.api.filtering import Filter
@@ -306,10 +308,13 @@ class RoomRestServlet(RestServlet):
             raise NotFoundError("Room not found")
 
         members = await self.store.get_users_in_room(room_id)
-        ret["joined_local_devices"] = await self.store.count_devices_by_users(members)
-        ret["forgotten"] = await self.store.is_locally_forgotten_room(room_id)
+        result = attr.asdict(ret)
+        result["joined_local_devices"] = await self.store.count_devices_by_users(
+            members
+        )
+        result["forgotten"] = await self.store.is_locally_forgotten_room(room_id)
 
-        return HTTPStatus.OK, ret
+        return HTTPStatus.OK, result
 
     async def on_DELETE(
         self, request: SynapseRequest, room_id: str
@@ -408,8 +413,8 @@ class RoomMembersRestServlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
-        ret = await self.store.get_room(room_id)
-        if not ret:
+        room = await self.store.get_room(room_id)
+        if not room:
             raise NotFoundError("Room not found")
 
         members = await self.store.get_users_in_room(room_id)
@@ -437,8 +442,8 @@ class RoomStateRestServlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
-        ret = await self.store.get_room(room_id)
-        if not ret:
+        room = await self.store.get_room(room_id)
+        if not room:
             raise NotFoundError("Room not found")
 
         event_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 7fe16130e7..9900498fbe 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -18,6 +18,8 @@ import secrets
 from http import HTTPStatus
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 
+import attr
+
 from synapse.api.constants import Direction, UserTypes
 from synapse.api.errors import Codes, NotFoundError, SynapseError
 from synapse.http.servlet import (
@@ -161,11 +163,13 @@ class UsersRestServletV2(RestServlet):
         )
 
         # If support for MSC3866 is not enabled, don't show the approval flag.
+        filter = None
         if not self._msc3866_enabled:
-            for user in users:
-                del user["approved"]
 
-        ret = {"users": users, "total": total}
+            def _filter(a: attr.Attribute) -> bool:
+                return a.name != "approved"
+
+            filter = _filter
+
+        ret = {"users": [attr.asdict(u, filter=filter) for u in users], "total": total}
         if (start + limit) < total:
             ret["next_token"] = str(start + len(users))
 
@@ -1266,6 +1270,46 @@ class AccountDataRestServlet(RestServlet):
         }
 
 
+class UserReplaceMasterCrossSigningKeyRestServlet(RestServlet):
+    """Allow a given user to replace their master cross-signing key without UIA.
+
+    This replacement is permitted for a limited period (currently 10 minutes).
+
+    While this is exposed via the admin API, it is intended for use by the
+    Matrix Authentication Service rather than by server admins.
+    """
+
+    PATTERNS = admin_patterns(
+        "/users/(?P<user_id>[^/]*)/_allow_cross_signing_replacement_without_uia"
+    )
+    REPLACEMENT_PERIOD_MS = 10 * 60 * 1000  # 10 minutes
+
+    def __init__(self, hs: "HomeServer"):
+        self._auth = hs.get_auth()
+        self._store = hs.get_datastores().main
+
+    async def on_POST(
+        self,
+        request: SynapseRequest,
+        user_id: str,
+    ) -> Tuple[int, JsonDict]:
+        await assert_requester_is_admin(self._auth, request)
+
+        if user_id is None:
+            raise NotFoundError("User not found")
+
+        timestamp = (
+            await self._store.allow_master_cross_signing_key_replacement_without_uia(
+                user_id, self.REPLACEMENT_PERIOD_MS
+            )
+        )
+
+        if timestamp is None:
+            raise NotFoundError("User has no master cross-signing key")
+
+        return HTTPStatus.OK, {"updatable_without_uia_before_ms": timestamp}
+
+
 class UserByExternalId(RestServlet):
     """Find a user based on an external ID from an auth provider"""
 
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index 641390cb30..0c0e82627d 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -299,19 +299,16 @@ class DeactivateAccountRestServlet(RestServlet):
 
         requester = await self.auth.get_user_by_req(request)
 
-        # allow ASes to deactivate their own users
-        if requester.app_service:
-            await self._deactivate_account_handler.deactivate_account(
-                requester.user.to_string(), body.erase, requester
+        # allow ASes to deactivate their own users:
+        # ASes don't need user-interactive auth
+        if not requester.app_service:
+            await self.auth_handler.validate_user_via_ui_auth(
+                requester,
+                request,
+                body.dict(exclude_unset=True),
+                "deactivate your account",
             )
-            return 200, {}
 
-        await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body.dict(exclude_unset=True),
-            "deactivate your account",
-        )
         result = await self._deactivate_account_handler.deactivate_account(
             requester.user.to_string(), body.erase, requester, id_server=body.id_server
         )
diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py
index 82944ca711..3534c3c259 100644
--- a/synapse/rest/client/directory.py
+++ b/synapse/rest/client/directory.py
@@ -147,7 +147,7 @@ class ClientDirectoryListServer(RestServlet):
         if room is None:
             raise NotFoundError("Unknown room")
 
-        return 200, {"visibility": "public" if room["is_public"] else "private"}
+        return 200, {"visibility": "public" if room[0] else "private"}
 
     class PutBody(RequestBodyModel):
         visibility: Literal["public", "private"] = "public"
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index 70b8be1aa2..add8045439 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -376,9 +376,10 @@ class SigningKeyUploadServlet(RestServlet):
         user_id = requester.user.to_string()
         body = parse_json_object_from_request(request)
 
-        is_cross_signing_setup = (
-            await self.e2e_keys_handler.is_cross_signing_set_up_for_user(user_id)
-        )
+        (
+            is_cross_signing_setup,
+            master_key_updatable_without_uia,
+        ) = await self.e2e_keys_handler.check_cross_signing_setup(user_id)
 
         # Before MSC3967 we required UIA both when setting up cross signing for the
         # first time and when resetting the device signing key. With MSC3967 we only
@@ -386,9 +387,14 @@ class SigningKeyUploadServlet(RestServlet):
         # time. Because there is no UIA in MSC3861, for now we throw an error if the
         # user tries to reset the device signing key when MSC3861 is enabled, but allow
         # first-time setup.
+        #
+        # XXX: We now have a get-out clause by which MAS can temporarily mark the master
+        # key as replaceable. It should do its own equivalent of user interactive auth
+        # before doing so.
         if self.hs.config.experimental.msc3861.enabled:
-            # There is no way to reset the device signing key with MSC3861
-            if is_cross_signing_setup:
+            # The auth service has to explicitly mark the master key as replaceable
+            # without UIA to reset the device signing key with MSC3861.
+            if is_cross_signing_setup and not master_key_updatable_without_uia:
                 raise SynapseError(
                     HTTPStatus.NOT_IMPLEMENTED,
                     "Resetting cross signing keys is not yet supported with MSC3861",
diff --git a/synapse/rest/media/create_resource.py b/synapse/rest/media/create_resource.py
new file mode 100644
index 0000000000..994afdf13c
--- /dev/null
+++ b/synapse/rest/media/create_resource.py
@@ -0,0 +1,83 @@
+# Copyright 2023 Beeper Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import re
+from typing import TYPE_CHECKING
+
+from synapse.api.errors import LimitExceededError
+from synapse.api.ratelimiting import Ratelimiter
+from synapse.http.server import respond_with_json
+from synapse.http.servlet import RestServlet
+from synapse.http.site import SynapseRequest
+
+if TYPE_CHECKING:
+    from synapse.media.media_repository import MediaRepository
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class CreateResource(RestServlet):
+    PATTERNS = [re.compile("/_matrix/media/v1/create")]
+
+    def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
+        super().__init__()
+
+        self.media_repo = media_repo
+        self.clock = hs.get_clock()
+        self.auth = hs.get_auth()
+        self.max_pending_media_uploads = hs.config.media.max_pending_media_uploads
+
+        # A rate limiter for creating new media IDs.
+        self._create_media_rate_limiter = Ratelimiter(
+            store=hs.get_datastores().main,
+            clock=self.clock,
+            cfg=hs.config.ratelimiting.rc_media_create,
+        )
+
+    async def on_POST(self, request: SynapseRequest) -> None:
+        requester = await self.auth.get_user_by_req(request)
+
+        # If the user has hit the rate limit for creating new media IDs,
+        # reject the request.
+        await self._create_media_rate_limiter.ratelimit(requester)
+
+        (
+            reached_pending_limit,
+            first_expiration_ts,
+        ) = await self.media_repo.reached_pending_media_limit(requester.user)
+        if reached_pending_limit:
+            raise LimitExceededError(
+                limiter_name="max_pending_media_uploads",
+                retry_after_ms=first_expiration_ts - self.clock.time_msec(),
+            )
+
+        content_uri, unused_expires_at = await self.media_repo.create_media_id(
+            requester.user
+        )
+
+        logger.info(
+            "Created Media URI %r that if unused will expire at %d",
+            content_uri,
+            unused_expires_at,
+        )
+        respond_with_json(
+            request,
+            200,
+            {
+                "content_uri": content_uri,
+                "unused_expires_at": unused_expires_at,
+            },
+            send_cors=True,
+        )
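
A sketch of how a client might call the new create endpoint; the homeserver URL and access token are placeholders, and error handling is omitted.

import json
import urllib.request

HOMESERVER = "https://synapse.example.org"  # placeholder
ACCESS_TOKEN = "syt_user_token"  # placeholder user access token

req = urllib.request.Request(
    f"{HOMESERVER}/_matrix/media/v1/create",
    data=b"",
    method="POST",
    headers={"Authorization": f"Bearer {ACCESS_TOKEN}"},
)
with urllib.request.urlopen(req) as resp:
    body = json.load(resp)

# body["content_uri"] is an mxc:// URI that can be referenced immediately and
# uploaded to later; body["unused_expires_at"] is when the reservation lapses.
print(body["content_uri"], body["unused_expires_at"])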
diff --git a/synapse/rest/media/download_resource.py b/synapse/rest/media/download_resource.py
index 65b9ff52fa..60cd87548c 100644
--- a/synapse/rest/media/download_resource.py
+++ b/synapse/rest/media/download_resource.py
@@ -17,9 +17,13 @@ import re
 from typing import TYPE_CHECKING, Optional
 
 from synapse.http.server import set_corp_headers, set_cors_headers
-from synapse.http.servlet import RestServlet, parse_boolean
+from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
 from synapse.http.site import SynapseRequest
-from synapse.media._base import respond_404
+from synapse.media._base import (
+    DEFAULT_MAX_TIMEOUT_MS,
+    MAXIMUM_ALLOWED_MAX_TIMEOUT_MS,
+    respond_404,
+)
 from synapse.util.stringutils import parse_and_validate_server_name
 
 if TYPE_CHECKING:
@@ -65,12 +69,16 @@ class DownloadResource(RestServlet):
         )
         # Limited non-standard form of CSP for IE11
         request.setHeader(b"X-Content-Security-Policy", b"sandbox;")
-        request.setHeader(
-            b"Referrer-Policy",
-            b"no-referrer",
+        request.setHeader(b"Referrer-Policy", b"no-referrer")
+        max_timeout_ms = parse_integer(
+            request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS
         )
+        max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS)
+
         if self._is_mine_server_name(server_name):
-            await self.media_repo.get_local_media(request, media_id, file_name)
+            await self.media_repo.get_local_media(
+                request, media_id, file_name, max_timeout_ms
+            )
         else:
             allow_remote = parse_boolean(request, "allow_remote", default=True)
             if not allow_remote:
@@ -83,5 +91,5 @@ class DownloadResource(RestServlet):
                 return
 
             await self.media_repo.get_remote_media(
-                request, server_name, media_id, file_name
+                request, server_name, media_id, file_name, max_timeout_ms
             )
diff --git a/synapse/rest/media/media_repository_resource.py b/synapse/rest/media/media_repository_resource.py
index 2089bb1029..ca65116b84 100644
--- a/synapse/rest/media/media_repository_resource.py
+++ b/synapse/rest/media/media_repository_resource.py
@@ -18,10 +18,11 @@ from synapse.config._base import ConfigError
 from synapse.http.server import HttpServer, JsonResource
 
 from .config_resource import MediaConfigResource
+from .create_resource import CreateResource
 from .download_resource import DownloadResource
 from .preview_url_resource import PreviewUrlResource
 from .thumbnail_resource import ThumbnailResource
-from .upload_resource import UploadResource
+from .upload_resource import AsyncUploadServlet, UploadServlet
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -91,8 +92,9 @@ class MediaRepositoryResource(JsonResource):
 
         # Note that many of these should not exist as v1 endpoints, but empirically
         # a lot of traffic still goes to them.
-
-        UploadResource(hs, media_repo).register(http_server)
+        CreateResource(hs, media_repo).register(http_server)
+        UploadServlet(hs, media_repo).register(http_server)
+        AsyncUploadServlet(hs, media_repo).register(http_server)
         DownloadResource(hs, media_repo).register(http_server)
         ThumbnailResource(hs, media_repo, media_repo.media_storage).register(
             http_server
diff --git a/synapse/rest/media/thumbnail_resource.py b/synapse/rest/media/thumbnail_resource.py
index 85b6bdbe72..681f2a5a27 100644
--- a/synapse/rest/media/thumbnail_resource.py
+++ b/synapse/rest/media/thumbnail_resource.py
@@ -23,6 +23,8 @@ from synapse.http.server import respond_with_json, set_corp_headers, set_cors_he
 from synapse.http.servlet import RestServlet, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
 from synapse.media._base import (
+    DEFAULT_MAX_TIMEOUT_MS,
+    MAXIMUM_ALLOWED_MAX_TIMEOUT_MS,
     FileInfo,
     ThumbnailInfo,
     respond_404,
@@ -75,15 +77,19 @@ class ThumbnailResource(RestServlet):
         method = parse_string(request, "method", "scale")
         # TODO Parse the Accept header to get a prioritised list of thumbnail types.
         m_type = "image/png"
+        max_timeout_ms = parse_integer(
+            request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS
+        )
+        max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS)
 
         if self._is_mine_server_name(server_name):
             if self.dynamic_thumbnails:
                 await self._select_or_generate_local_thumbnail(
-                    request, media_id, width, height, method, m_type
+                    request, media_id, width, height, method, m_type, max_timeout_ms
                 )
             else:
                 await self._respond_local_thumbnail(
-                    request, media_id, width, height, method, m_type
+                    request, media_id, width, height, method, m_type, max_timeout_ms
                 )
             self.media_repo.mark_recently_accessed(None, media_id)
         else:
@@ -95,14 +101,21 @@ class ThumbnailResource(RestServlet):
                 respond_404(request)
                 return
 
-            if self.dynamic_thumbnails:
-                await self._select_or_generate_remote_thumbnail(
-                    request, server_name, media_id, width, height, method, m_type
-                )
-            else:
-                await self._respond_remote_thumbnail(
-                    request, server_name, media_id, width, height, method, m_type
-                )
+            remote_resp_function = (
+                self._select_or_generate_remote_thumbnail
+                if self.dynamic_thumbnails
+                else self._respond_remote_thumbnail
+            )
+            await remote_resp_function(
+                request,
+                server_name,
+                media_id,
+                width,
+                height,
+                method,
+                m_type,
+                max_timeout_ms,
+            )
             self.media_repo.mark_recently_accessed(server_name, media_id)
 
     async def _respond_local_thumbnail(
@@ -113,15 +126,12 @@ class ThumbnailResource(RestServlet):
         height: int,
         method: str,
         m_type: str,
+        max_timeout_ms: int,
     ) -> None:
-        media_info = await self.store.get_local_media(media_id)
-
+        media_info = await self.media_repo.get_local_media_info(
+            request, media_id, max_timeout_ms
+        )
         if not media_info:
-            respond_404(request)
-            return
-        if media_info["quarantined_by"]:
-            logger.info("Media is quarantined")
-            respond_404(request)
             return
 
         thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
@@ -134,7 +144,7 @@ class ThumbnailResource(RestServlet):
             thumbnail_infos,
             media_id,
             media_id,
-            url_cache=bool(media_info["url_cache"]),
+            url_cache=bool(media_info.url_cache),
             server_name=None,
         )
 
@@ -146,15 +156,13 @@ class ThumbnailResource(RestServlet):
         desired_height: int,
         desired_method: str,
         desired_type: str,
+        max_timeout_ms: int,
     ) -> None:
-        media_info = await self.store.get_local_media(media_id)
+        media_info = await self.media_repo.get_local_media_info(
+            request, media_id, max_timeout_ms
+        )
 
         if not media_info:
-            respond_404(request)
-            return
-        if media_info["quarantined_by"]:
-            logger.info("Media is quarantined")
-            respond_404(request)
             return
 
         thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
@@ -168,7 +176,7 @@ class ThumbnailResource(RestServlet):
                 file_info = FileInfo(
                     server_name=None,
                     file_id=media_id,
-                    url_cache=media_info["url_cache"],
+                    url_cache=bool(media_info.url_cache),
                     thumbnail=info,
                 )
 
@@ -188,7 +196,7 @@ class ThumbnailResource(RestServlet):
             desired_height,
             desired_method,
             desired_type,
-            url_cache=bool(media_info["url_cache"]),
+            url_cache=bool(media_info.url_cache),
         )
 
         if file_path:
@@ -206,14 +214,20 @@ class ThumbnailResource(RestServlet):
         desired_height: int,
         desired_method: str,
         desired_type: str,
+        max_timeout_ms: int,
     ) -> None:
-        media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
+        media_info = await self.media_repo.get_remote_media_info(
+            server_name, media_id, max_timeout_ms
+        )
+        if not media_info:
+            respond_404(request)
+            return
 
         thumbnail_infos = await self.store.get_remote_media_thumbnails(
             server_name, media_id
         )
 
-        file_id = media_info["filesystem_id"]
+        file_id = media_info.filesystem_id
 
         for info in thumbnail_infos:
             t_w = info.width == desired_width
@@ -224,7 +238,7 @@ class ThumbnailResource(RestServlet):
             if t_w and t_h and t_method and t_type:
                 file_info = FileInfo(
                     server_name=server_name,
-                    file_id=media_info["filesystem_id"],
+                    file_id=file_id,
                     thumbnail=info,
                 )
 
@@ -263,11 +277,16 @@ class ThumbnailResource(RestServlet):
         height: int,
         method: str,
         m_type: str,
+        max_timeout_ms: int,
     ) -> None:
         # TODO: Don't download the whole remote file
         # We should proxy the thumbnail from the remote server instead of
         # downloading the remote file and generating our own thumbnails.
-        media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
+        media_info = await self.media_repo.get_remote_media_info(
+            server_name, media_id, max_timeout_ms
+        )
+        if not media_info:
+            return
 
         thumbnail_infos = await self.store.get_remote_media_thumbnails(
             server_name, media_id
@@ -280,7 +299,7 @@ class ThumbnailResource(RestServlet):
             m_type,
             thumbnail_infos,
             media_id,
-            media_info["filesystem_id"],
+            media_info.filesystem_id,
             url_cache=False,
             server_name=server_name,
         )
diff --git a/synapse/rest/media/upload_resource.py b/synapse/rest/media/upload_resource.py
index 949326d85d..62d3e228a8 100644
--- a/synapse/rest/media/upload_resource.py
+++ b/synapse/rest/media/upload_resource.py
@@ -15,7 +15,7 @@
 
 import logging
 import re
-from typing import IO, TYPE_CHECKING, Dict, List, Optional
+from typing import IO, TYPE_CHECKING, Dict, List, Optional, Tuple
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.server import respond_with_json
@@ -29,23 +29,24 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
+# The name of the lock to use when uploading media.
+_UPLOAD_MEDIA_LOCK_NAME = "upload_media"
 
-class UploadResource(RestServlet):
-    PATTERNS = [re.compile("/_matrix/media/(r0|v3|v1)/upload")]
 
+class BaseUploadServlet(RestServlet):
     def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
         super().__init__()
 
         self.media_repo = media_repo
         self.filepaths = media_repo.filepaths
         self.store = hs.get_datastores().main
-        self.clock = hs.get_clock()
+        self.server_name = hs.hostname
         self.auth = hs.get_auth()
         self.max_upload_size = hs.config.media.max_upload_size
-        self.clock = hs.get_clock()
 
-    async def on_POST(self, request: SynapseRequest) -> None:
-        requester = await self.auth.get_user_by_req(request)
+    def _get_file_metadata(
+        self, request: SynapseRequest
+    ) -> Tuple[int, Optional[str], str]:
         raw_content_length = request.getHeader("Content-Length")
         if raw_content_length is None:
             raise SynapseError(msg="Request must specify a Content-Length", code=400)
@@ -88,6 +89,16 @@ class UploadResource(RestServlet):
         #     disposition = headers.getRawHeaders(b"Content-Disposition")[0]
         # TODO(markjh): parse content-disposition
 
+        return content_length, upload_name, media_type
+
+
+class UploadServlet(BaseUploadServlet):
+    PATTERNS = [re.compile("/_matrix/media/(r0|v3|v1)/upload$")]
+
+    async def on_POST(self, request: SynapseRequest) -> None:
+        requester = await self.auth.get_user_by_req(request)
+        content_length, upload_name, media_type = self._get_file_metadata(request)
+
         try:
             content: IO = request.content  # type: ignore
             content_uri = await self.media_repo.create_content(
@@ -103,3 +114,53 @@ class UploadResource(RestServlet):
         respond_with_json(
             request, 200, {"content_uri": str(content_uri)}, send_cors=True
         )
+
+
+class AsyncUploadServlet(BaseUploadServlet):
+    PATTERNS = [
+        re.compile(
+            "/_matrix/media/v3/upload/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$"
+        )
+    ]
+
+    async def on_PUT(
+        self, request: SynapseRequest, server_name: str, media_id: str
+    ) -> None:
+        requester = await self.auth.get_user_by_req(request)
+
+        if server_name != self.server_name:
+            raise SynapseError(
+                404,
+                "Non-local server name specified",
+                errcode=Codes.NOT_FOUND,
+            )
+
+        lock = await self.store.try_acquire_lock(_UPLOAD_MEDIA_LOCK_NAME, media_id)
+        if not lock:
+            raise SynapseError(
+                409,
+                "Media ID cannot be overwritten",
+                errcode=Codes.CANNOT_OVERWRITE_MEDIA,
+            )
+
+        async with lock:
+            await self.media_repo.verify_can_upload(media_id, requester.user)
+            content_length, upload_name, media_type = self._get_file_metadata(request)
+
+            try:
+                content: IO = request.content  # type: ignore
+                await self.media_repo.update_content(
+                    media_id,
+                    media_type,
+                    upload_name,
+                    content,
+                    content_length,
+                    requester.user,
+                )
+            except SpamMediaException:
+                # For media uploads we want to respond with a 400 rather than the
+                # default 404, as a 404 would just be confusing.
+                raise SynapseError(400, "Bad content")
+
+            logger.info("Uploaded content for media ID %r", media_id)
+            respond_with_json(request, 200, {}, send_cors=True)
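
Together with the create endpoint earlier in this diff, the new PUT route lets a client reserve an mxc:// URI first and upload the bytes later. A hedged sketch of that second step with the standard library; the URL, token, and file name are placeholders.

import urllib.request

HOMESERVER = "https://synapse.example.org"  # placeholder
ACCESS_TOKEN = "syt_user_token"  # placeholder
content_uri = "mxc://synapse.example.org/AbCdEf123"  # from /_matrix/media/v1/create

server_name, media_id = content_uri[len("mxc://") :].split("/", 1)

with open("photo.png", "rb") as f:  # placeholder file
    payload = f.read()

req = urllib.request.Request(
    f"{HOMESERVER}/_matrix/media/v3/upload/{server_name}/{media_id}",
    data=payload,
    method="PUT",
    headers={
        "Authorization": f"Bearer {ACCESS_TOKEN}",
        "Content-Type": "image/png",
    },
)
# A 200 with an empty JSON body means the reserved media ID now has content;
# a 409 (CANNOT_OVERWRITE_MEDIA) means it has already been written to.
urllib.request.urlopen(req)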
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 12829d3d7d..62fbd05534 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -28,6 +28,7 @@ from typing import (
     Sequence,
     Tuple,
     Type,
+    cast,
 )
 
 import attr
@@ -48,7 +49,11 @@ else:
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
-    from synapse.storage.database import DatabasePool, LoggingTransaction
+    from synapse.storage.database import (
+        DatabasePool,
+        LoggingDatabaseConnection,
+        LoggingTransaction,
+    )
 
 logger = logging.getLogger(__name__)
 
@@ -488,14 +493,14 @@ class BackgroundUpdater:
             True if we have finished running all the background updates, otherwise False
         """
 
-        def get_background_updates_txn(txn: Cursor) -> List[Dict[str, Any]]:
+        def get_background_updates_txn(txn: Cursor) -> List[Tuple[str, Optional[str]]]:
             txn.execute(
                 """
                 SELECT update_name, depends_on FROM background_updates
                 ORDER BY ordering, update_name
                 """
             )
-            return self.db_pool.cursor_to_dict(txn)
+            return cast(List[Tuple[str, Optional[str]]], txn.fetchall())
 
         if not self._current_background_update:
             all_pending_updates = await self.db_pool.runInteraction(
@@ -507,14 +512,13 @@ class BackgroundUpdater:
                 return True
 
             # find the first update which isn't dependent on another one in the queue.
-            pending = {update["update_name"] for update in all_pending_updates}
-            for upd in all_pending_updates:
-                depends_on = upd["depends_on"]
+            pending = {update_name for update_name, depends_on in all_pending_updates}
+            for update_name, depends_on in all_pending_updates:
                 if not depends_on or depends_on not in pending:
                     break
                 logger.info(
                     "Not starting on bg update %s until %s is done",
-                    upd["update_name"],
+                    update_name,
                     depends_on,
                 )
             else:
@@ -524,7 +528,7 @@ class BackgroundUpdater:
                     "another: dependency cycle?"
                 )
 
-            self._current_background_update = upd["update_name"]
+            self._current_background_update = update_name
 
         # We have a background update to run, otherwise we would have returned
         # early.
@@ -746,10 +750,10 @@ class BackgroundUpdater:
                 The named index will be dropped upon completion of the new index.
         """
 
-        def create_index_psql(conn: Connection) -> None:
+        def create_index_psql(conn: "LoggingDatabaseConnection") -> None:
             conn.rollback()
             # postgres insists on autocommit for the index
-            conn.set_session(autocommit=True)  # type: ignore
+            conn.engine.attempt_to_set_autocommit(conn.conn, True)
 
             try:
                 c = conn.cursor()
@@ -793,9 +797,9 @@ class BackgroundUpdater:
                 undo_timeout_sql = f"SET statement_timeout = {default_timeout}"
                 conn.cursor().execute(undo_timeout_sql)
 
-                conn.set_session(autocommit=False)  # type: ignore
+                conn.engine.attempt_to_set_autocommit(conn.conn, False)
 
-        def create_index_sqlite(conn: Connection) -> None:
+        def create_index_sqlite(conn: "LoggingDatabaseConnection") -> None:
             # Sqlite doesn't support concurrent creation of indexes.
             #
             # We assume that sqlite doesn't give us invalid indices; however
@@ -825,7 +829,9 @@ class BackgroundUpdater:
                 c.execute(sql)
 
         if isinstance(self.db_pool.engine, engines.PostgresEngine):
-            runner: Optional[Callable[[Connection], None]] = create_index_psql
+            runner: Optional[
+                Callable[[LoggingDatabaseConnection], None]
+            ] = create_index_psql
         elif psql_only:
             runner = None
         else:
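
The background-update scheduler above now iterates over (update_name, depends_on) tuples rather than dicts. A small standalone illustration of the same selection loop:

from typing import List, Optional, Tuple

# (update_name, depends_on) rows as returned by get_background_updates_txn.
rows: List[Tuple[str, Optional[str]]] = [
    ("b_depends_on_a", "a_first"),
    ("a_first", None),
    ("c_depends_on_b", "b_depends_on_a"),
]

pending = {update_name for update_name, _ in rows}
for update_name, depends_on in rows:
    # Pick the first update whose dependency is absent or already completed.
    if not depends_on or depends_on not in pending:
        break
else:
    raise RuntimeError("every update depends on another: dependency cycle?")

print("next background update:", update_name)  # -> a_first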
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index f39ae2d635..1529c86cc5 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -542,13 +542,15 @@ class EventsPersistenceStorageController:
         return await res.get_state(self._state_controller, StateFilter.all())
 
     async def _persist_event_batch(
-        self, _room_id: str, task: _PersistEventsTask
+        self, room_id: str, task: _PersistEventsTask
     ) -> Dict[str, str]:
         """Callback for the _event_persist_queue
 
         Calculates the change to current state and forward extremities, and
         persists the given events and with those updates.
 
+        Assumes that we are only persisting events for one room at a time.
+
         Returns:
             A dictionary of event ID to event ID we didn't persist as we already
             had another event persisted with the same TXN ID.
@@ -594,140 +596,23 @@ class EventsPersistenceStorageController:
             # We can't easily parallelize these since different chunks
             # might contain the same event. :(
 
-            # NB: Assumes that we are only persisting events for one room
-            # at a time.
-
-            # map room_id->set[event_ids] giving the new forward
-            # extremities in each room
-            new_forward_extremities: Dict[str, Set[str]] = {}
-
-            # map room_id->(to_delete, to_insert) where to_delete is a list
-            # of type/state keys to remove from current state, and to_insert
-            # is a map (type,key)->event_id giving the state delta in each
-            # room
-            state_delta_for_room: Dict[str, DeltaState] = {}
+            new_forward_extremities = None
+            state_delta_for_room = None
 
             if not backfilled:
                 with Measure(self._clock, "_calculate_state_and_extrem"):
-                    # Work out the new "current state" for each room.
+                    # Work out the new "current state" for the room.
                     # We do this by working out what the new extremities are and then
                     # calculating the state from that.
-                    events_by_room: Dict[str, List[Tuple[EventBase, EventContext]]] = {}
-                    for event, context in chunk:
-                        events_by_room.setdefault(event.room_id, []).append(
-                            (event, context)
-                        )
-
-                    for room_id, ev_ctx_rm in events_by_room.items():
-                        latest_event_ids = (
-                            await self.main_store.get_latest_event_ids_in_room(room_id)
-                        )
-                        new_latest_event_ids = await self._calculate_new_extremities(
-                            room_id, ev_ctx_rm, latest_event_ids
-                        )
-
-                        if new_latest_event_ids == latest_event_ids:
-                            # No change in extremities, so no change in state
-                            continue
-
-                        # there should always be at least one forward extremity.
-                        # (except during the initial persistence of the send_join
-                        # results, in which case there will be no existing
-                        # extremities, so we'll `continue` above and skip this bit.)
-                        assert new_latest_event_ids, "No forward extremities left!"
-
-                        new_forward_extremities[room_id] = new_latest_event_ids
-
-                        len_1 = (
-                            len(latest_event_ids) == 1
-                            and len(new_latest_event_ids) == 1
-                        )
-                        if len_1:
-                            all_single_prev_not_state = all(
-                                len(event.prev_event_ids()) == 1
-                                and not event.is_state()
-                                for event, ctx in ev_ctx_rm
-                            )
-                            # Don't bother calculating state if they're just
-                            # a long chain of single ancestor non-state events.
-                            if all_single_prev_not_state:
-                                continue
-
-                        state_delta_counter.inc()
-                        if len(new_latest_event_ids) == 1:
-                            state_delta_single_event_counter.inc()
-
-                            # This is a fairly handwavey check to see if we could
-                            # have guessed what the delta would have been when
-                            # processing one of these events.
-                            # What we're interested in is if the latest extremities
-                            # were the same when we created the event as they are
-                            # now. When this server creates a new event (as opposed
-                            # to receiving it over federation) it will use the
-                            # forward extremities as the prev_events, so we can
-                            # guess this by looking at the prev_events and checking
-                            # if they match the current forward extremities.
-                            for ev, _ in ev_ctx_rm:
-                                prev_event_ids = set(ev.prev_event_ids())
-                                if latest_event_ids == prev_event_ids:
-                                    state_delta_reuse_delta_counter.inc()
-                                    break
-
-                        logger.debug("Calculating state delta for room %s", room_id)
-                        with Measure(
-                            self._clock, "persist_events.get_new_state_after_events"
-                        ):
-                            res = await self._get_new_state_after_events(
-                                room_id,
-                                ev_ctx_rm,
-                                latest_event_ids,
-                                new_latest_event_ids,
-                            )
-                            current_state, delta_ids, new_latest_event_ids = res
-
-                            # there should always be at least one forward extremity.
-                            # (except during the initial persistence of the send_join
-                            # results, in which case there will be no existing
-                            # extremities, so we'll `continue` above and skip this bit.)
-                            assert new_latest_event_ids, "No forward extremities left!"
-
-                            new_forward_extremities[room_id] = new_latest_event_ids
-
-                        # If either are not None then there has been a change,
-                        # and we need to work out the delta (or use that
-                        # given)
-                        delta = None
-                        if delta_ids is not None:
-                            # If there is a delta we know that we've
-                            # only added or replaced state, never
-                            # removed keys entirely.
-                            delta = DeltaState([], delta_ids)
-                        elif current_state is not None:
-                            with Measure(
-                                self._clock, "persist_events.calculate_state_delta"
-                            ):
-                                delta = await self._calculate_state_delta(
-                                    room_id, current_state
-                                )
-
-                        if delta:
-                            # If we have a change of state then lets check
-                            # whether we're actually still a member of the room,
-                            # or if our last user left. If we're no longer in
-                            # the room then we delete the current state and
-                            # extremities.
-                            is_still_joined = await self._is_server_still_joined(
-                                room_id,
-                                ev_ctx_rm,
-                                delta,
-                            )
-                            if not is_still_joined:
-                                logger.info("Server no longer in room %s", room_id)
-                                delta.no_longer_in_room = True
-
-                            state_delta_for_room[room_id] = delta
+                    (
+                        new_forward_extremities,
+                        state_delta_for_room,
+                    ) = await self._calculate_new_forward_extremities_and_state_delta(
+                        room_id, chunk
+                    )
 
             await self.persist_events_store._persist_events_and_state_updates(
+                room_id,
                 chunk,
                 state_delta_for_room=state_delta_for_room,
                 new_forward_extremities=new_forward_extremities,
@@ -737,6 +622,117 @@ class EventsPersistenceStorageController:
 
         return replaced_events
 
+    async def _calculate_new_forward_extremities_and_state_delta(
+        self, room_id: str, ev_ctx_rm: List[Tuple[EventBase, EventContext]]
+    ) -> Tuple[Optional[Set[str]], Optional[DeltaState]]:
+        """Calculates the new forward extremities and state delta for a room
+        given events to persist.
+
+        Assumes that we are only persisting events for one room at a time.
+
+        Returns:
+            A tuple of:
+                A set of str giving the new forward extremities of the room
+
+                The state delta for the room.
+        """
+
+        latest_event_ids = await self.main_store.get_latest_event_ids_in_room(room_id)
+        new_latest_event_ids = await self._calculate_new_extremities(
+            room_id, ev_ctx_rm, latest_event_ids
+        )
+
+        if new_latest_event_ids == latest_event_ids:
+            # No change in extremities, so no change in state
+            return (None, None)
+
+        # there should always be at least one forward extremity.
+        # (except during the initial persistence of the send_join
+        # results, in which case there will be no existing
+        # extremities, so we'll `return` above and skip this bit.)
+        assert new_latest_event_ids, "No forward extremities left!"
+
+        new_forward_extremities = new_latest_event_ids
+
+        len_1 = len(latest_event_ids) == 1 and len(new_latest_event_ids) == 1
+        if len_1:
+            all_single_prev_not_state = all(
+                len(event.prev_event_ids()) == 1 and not event.is_state()
+                for event, ctx in ev_ctx_rm
+            )
+            # Don't bother calculating state if they're just
+            # a long chain of single ancestor non-state events.
+            if all_single_prev_not_state:
+                return (new_forward_extremities, None)
+
+        state_delta_counter.inc()
+        if len(new_latest_event_ids) == 1:
+            state_delta_single_event_counter.inc()
+
+            # This is a fairly handwavey check to see if we could
+            # have guessed what the delta would have been when
+            # processing one of these events.
+            # What we're interested in is if the latest extremities
+            # were the same when we created the event as they are
+            # now. When this server creates a new event (as opposed
+            # to receiving it over federation) it will use the
+            # forward extremities as the prev_events, so we can
+            # guess this by looking at the prev_events and checking
+            # if they match the current forward extremities.
+            for ev, _ in ev_ctx_rm:
+                prev_event_ids = set(ev.prev_event_ids())
+                if latest_event_ids == prev_event_ids:
+                    state_delta_reuse_delta_counter.inc()
+                    break
+
+        logger.debug("Calculating state delta for room %s", room_id)
+        with Measure(self._clock, "persist_events.get_new_state_after_events"):
+            res = await self._get_new_state_after_events(
+                room_id,
+                ev_ctx_rm,
+                latest_event_ids,
+                new_latest_event_ids,
+            )
+            current_state, delta_ids, new_latest_event_ids = res
+
+            # there should always be at least one forward extremity.
+            # (except during the initial persistence of the send_join
+            # results, in which case there will be no existing
+            # extremities, so we'll `continue` above and skip this bit.)
+            assert new_latest_event_ids, "No forward extremities left!"
+
+            new_forward_extremities = new_latest_event_ids
+
+        # If either is not None then there has been a change,
+        # and we need to work out the delta (or use the one
+        # given).
+        delta = None
+        if delta_ids is not None:
+            # If there is a delta we know that we've
+            # only added or replaced state, never
+            # removed keys entirely.
+            delta = DeltaState([], delta_ids)
+        elif current_state is not None:
+            with Measure(self._clock, "persist_events.calculate_state_delta"):
+                delta = await self._calculate_state_delta(room_id, current_state)
+
+        if delta:
+            # If we have a change of state then let's check
+            # whether we're actually still a member of the room,
+            # or if our last user left. If we're no longer in
+            # the room then we delete the current state and
+            # extremities.
+            is_still_joined = await self._is_server_still_joined(
+                room_id,
+                ev_ctx_rm,
+                delta,
+            )
+            if not is_still_joined:
+                logger.info("Server no longer in room %s", room_id)
+                delta.no_longer_in_room = True
+
+        return (new_forward_extremities, delta)
+
     async def _calculate_new_extremities(
         self,
         room_id: str,
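
The helper above replaces the old per-room loop: it returns (None, None) when the forward extremities are unchanged, and otherwise the new extremities plus an optional DeltaState (None when there is no state change to apply). A minimal standalone sketch of that contract (the function name below is made up for illustration and is not Synapse's API):

    from typing import Optional, Set

    def describe_persist_result(
        new_forward_extremities: Optional[Set[str]],
        state_delta: Optional[object],
    ) -> str:
        # Mirrors the return shapes of the helper in the diff above.
        if new_forward_extremities is None:
            return "extremities unchanged; nothing to do"
        if state_delta is None:
            return "extremities updated; no state delta to apply"
        return "extremities updated; apply the state delta"

    assert describe_persist_result(None, None) == "extremities unchanged; nothing to do"
    assert describe_persist_result({"$event1"}, None).endswith("no state delta to apply")
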
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index a4e7048368..eb34de4df5 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -18,7 +18,6 @@ import logging
 import time
 import types
 from collections import defaultdict
-from sys import intern
 from time import monotonic as monotonic_time
 from typing import (
     TYPE_CHECKING,
@@ -1042,20 +1041,6 @@ class DatabasePool:
             self._db_pool.runWithConnection(inner_func, *args, **kwargs)
         )
 
-    @staticmethod
-    def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]:
-        """Converts a SQL cursor into an list of dicts.
-
-        Args:
-            cursor: The DBAPI cursor which has executed a query.
-        Returns:
-            A list of dicts where the key is the column header.
-        """
-        assert cursor.description is not None, "cursor.description was None"
-        col_headers = [intern(str(column[0])) for column in cursor.description]
-        results = [dict(zip(col_headers, row)) for row in cursor]
-        return results
-
     async def execute(self, desc: str, query: str, *args: Any) -> List[Tuple[Any, ...]]:
         """Runs a single query for a result set.
 
@@ -1131,8 +1116,8 @@ class DatabasePool:
     def simple_insert_many_txn(
         txn: LoggingTransaction,
         table: str,
-        keys: Collection[str],
-        values: Iterable[Iterable[Any]],
+        keys: Sequence[str],
+        values: Collection[Iterable[Any]],
     ) -> None:
         """Executes an INSERT query on the named table.
 
@@ -1145,6 +1130,9 @@ class DatabasePool:
             keys: list of column names
             values: for each row, a list of values in the same order as `keys`
         """
+        # If there's nothing to insert, then skip executing the query.
+        if not values:
+            return
 
         if isinstance(txn.database_engine, PostgresEngine):
             # We use `execute_values` as it can be a lot faster than `execute_batch`,
@@ -1416,12 +1404,12 @@ class DatabasePool:
             allvalues.update(values)
             latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
 
-        sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) %s DO %s" % (
+        sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) %sDO %s" % (
             table,
             ", ".join(k for k in allvalues),
             ", ".join("?" for _ in allvalues),
             ", ".join(k for k in keyvalues),
-            f"WHERE {where_clause}" if where_clause else "",
+            f"WHERE {where_clause} " if where_clause else "",
             latter,
         )
         txn.execute(sql, list(allvalues.values()))
@@ -1470,7 +1458,7 @@ class DatabasePool:
         key_names: Collection[str],
         key_values: Collection[Iterable[Any]],
         value_names: Collection[str],
-        value_values: Iterable[Iterable[Any]],
+        value_values: Collection[Iterable[Any]],
     ) -> None:
         """
         Upsert, many times.
@@ -1483,6 +1471,19 @@ class DatabasePool:
             value_values: A list of each row's value column values.
                 Ignored if value_names is empty.
         """
+        # If there's nothing to upsert, then skip executing the query.
+        if not key_values:
+            return
+
+        # No value columns, therefore make a blank list so that the following
+        # zip() works correctly.
+        if not value_names:
+            value_values = [() for x in range(len(key_values))]
+        elif len(value_values) != len(key_values):
+            raise ValueError(
+                f"{len(key_values)} key rows and {len(value_values)} value rows: should be the same number."
+            )
+
         if table not in self._unsafe_to_upsert_tables:
             return self.simple_upsert_many_txn_native_upsert(
                 txn, table, key_names, key_values, value_names, value_values
@@ -1517,10 +1518,6 @@ class DatabasePool:
             value_values: A list of each row's value column values.
                 Ignored if value_names is empty.
         """
-        # No value columns, therefore make a blank list so that the following
-        # zip() works correctly.
-        if not value_names:
-            value_values = [() for x in range(len(key_values))]
 
         # Lock the table just once, to prevent it being done once per row.
         # Note that, according to Postgres' documentation, once obtained,
@@ -1558,10 +1555,7 @@ class DatabasePool:
         allnames.extend(value_names)
 
         if not value_names:
-            # No value columns, therefore make a blank list so that the
-            # following zip() works correctly.
             latter = "NOTHING"
-            value_values = [() for x in range(len(key_values))]
         else:
             latter = "UPDATE SET " + ", ".join(
                 k + "=EXCLUDED." + k for k in value_names
@@ -1603,7 +1597,7 @@ class DatabasePool:
         retcols: Collection[str],
         allow_none: Literal[False] = False,
         desc: str = "simple_select_one",
-    ) -> Dict[str, Any]:
+    ) -> Tuple[Any, ...]:
         ...
 
     @overload
@@ -1614,7 +1608,7 @@ class DatabasePool:
         retcols: Collection[str],
         allow_none: Literal[True] = True,
         desc: str = "simple_select_one",
-    ) -> Optional[Dict[str, Any]]:
+    ) -> Optional[Tuple[Any, ...]]:
         ...
 
     async def simple_select_one(
@@ -1624,7 +1618,7 @@ class DatabasePool:
         retcols: Collection[str],
         allow_none: bool = False,
         desc: str = "simple_select_one",
-    ) -> Optional[Dict[str, Any]]:
+    ) -> Optional[Tuple[Any, ...]]:
         """Executes a SELECT query on the named table, which is expected to
         return a single row, returning multiple columns from it.
 
@@ -1925,6 +1919,7 @@ class DatabasePool:
         Returns:
             The results as a list of tuples.
         """
+        # If there's nothing to select, then skip executing the query.
         if not iterable:
             return []
 
@@ -2059,13 +2054,13 @@ class DatabasePool:
             raise ValueError(
                 f"{len(key_values)} key rows and {len(value_values)} value rows: should be the same number."
             )
+        # If there is nothing to update, then skip executing the query.
+        if not key_values:
+            return
 
         # List of tuples of (value values, then key values)
         # (This matches the order needed for the query)
-        args = [tuple(x) + tuple(y) for x, y in zip(value_values, key_values)]
-
-        for ks, vs in zip(key_values, value_values):
-            args.append(tuple(vs) + tuple(ks))
+        args = [tuple(vv) + tuple(kv) for vv, kv in zip(value_values, key_values)]
 
         # 'col1 = ?, col2 = ?, ...'
         set_clause = ", ".join(f"{n} = ?" for n in value_names)
@@ -2077,9 +2072,7 @@ class DatabasePool:
             where_clause = ""
 
         # UPDATE mytable SET col1 = ?, col2 = ? WHERE col3 = ? AND col4 = ?
-        sql = f"""
-            UPDATE {table} SET {set_clause} {where_clause}
-        """
+        sql = f"UPDATE {table} SET {set_clause} {where_clause}"
 
         txn.execute_batch(sql, args)
 
@@ -2134,7 +2127,7 @@ class DatabasePool:
         keyvalues: Dict[str, Any],
         retcols: Collection[str],
         allow_none: bool = False,
-    ) -> Optional[Dict[str, Any]]:
+    ) -> Optional[Tuple[Any, ...]]:
         select_sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
 
         if keyvalues:
@@ -2152,7 +2145,7 @@ class DatabasePool:
         if txn.rowcount > 1:
             raise StoreError(500, "More than one row matched (%s)" % (table,))
 
-        return dict(zip(retcols, row))
+        return row
 
     async def simple_delete_one(
         self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"
@@ -2295,11 +2288,10 @@ class DatabasePool:
         Returns:
             Number of rows deleted
         """
+        # If there's nothing to delete, then skip executing the query.
         if not values:
             return 0
 
-        sql = "DELETE FROM %s" % table
-
         clause, values = make_in_list_sql_clause(txn.database_engine, column, values)
         clauses = [clause]
 
@@ -2307,8 +2299,7 @@ class DatabasePool:
             clauses.append("%s = ?" % (key,))
             values.append(value)
 
-        if clauses:
-            sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
+        sql = "DELETE FROM %s WHERE %s" % (table, " AND ".join(clauses))
         txn.execute(sql, values)
 
         return txn.rowcount
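
In simple_update_many_txn above, the duplicated argument-building pass is gone: each parameter tuple now carries the value columns first and the key columns last, matching the generated "SET col = ? ... WHERE key = ?" statement. A small illustration with made-up rows:

    # Hypothetical rows: displayname updates keyed by user_id.
    value_values = [("Alice",), ("Bob",)]
    key_values = [("@alice:example.org",), ("@bob:example.org",)]

    # Value columns first, then key columns, one tuple per row.
    args = [tuple(vv) + tuple(kv) for vv, kv in zip(value_values, key_values)]
    assert args == [
        ("Alice", "@alice:example.org"),
        ("Bob", "@bob:example.org"),
    ]
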
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index 7aa24ccf21..b57e260fe0 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -45,7 +45,7 @@ class Databases(Generic[DataStoreT]):
     """
 
     databases: List[DatabasePool]
-    main: "DataStore"  # FIXME: #11165: actually an instance of `main_store_class`
+    main: "DataStore"  # FIXME: https://github.com/matrix-org/synapse/issues/11165: actually an instance of `main_store_class`
     state: StateGroupDataStore
     persist_events: Optional[PersistEventsStore]
 
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 840d725114..89f4077351 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -17,6 +17,8 @@
 import logging
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast
 
+import attr
+
 from synapse.api.constants import Direction
 from synapse.config.homeserver import HomeServerConfig
 from synapse.storage._base import make_in_list_sql_clause
@@ -28,7 +30,7 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.stats import UserSortOrder
 from synapse.storage.engines import BaseDatabaseEngine
 from synapse.storage.types import Cursor
-from synapse.types import JsonDict, get_domain_from_id
+from synapse.types import get_domain_from_id
 
 from .account_data import AccountDataStore
 from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
@@ -82,6 +84,25 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class UserPaginateResponse:
+    """This is very similar to UserInfo, but not quite the same."""
+
+    name: str
+    user_type: Optional[str]
+    is_guest: bool
+    admin: bool
+    deactivated: bool
+    shadow_banned: bool
+    displayname: Optional[str]
+    avatar_url: Optional[str]
+    creation_ts: Optional[int]
+    approved: bool
+    erased: bool
+    last_seen_ts: int
+    locked: bool
+
+
 class DataStore(
     EventsBackgroundUpdatesStore,
     ExperimentalFeaturesStore,
@@ -156,7 +177,7 @@ class DataStore(
         approved: bool = True,
         not_user_types: Optional[List[str]] = None,
         locked: bool = False,
-    ) -> Tuple[List[JsonDict], int]:
+    ) -> Tuple[List[UserPaginateResponse], int]:
         """Function to retrieve a paginated list of users from
         users list. This will return a list of users and the
         total number of users matching the filter criteria.
@@ -182,7 +203,7 @@ class DataStore(
 
         def get_users_paginate_txn(
             txn: LoggingTransaction,
-        ) -> Tuple[List[JsonDict], int]:
+        ) -> Tuple[List[UserPaginateResponse], int]:
             filters = []
             args: list = []
 
@@ -282,13 +303,24 @@ class DataStore(
             """
             args += [limit, start]
             txn.execute(sql, args)
-            users = self.db_pool.cursor_to_dict(txn)
-
-            # some of those boolean values are returned as integers when we're on SQLite
-            columns_to_boolify = ["erased"]
-            for user in users:
-                for column in columns_to_boolify:
-                    user[column] = bool(user[column])
+            users = [
+                UserPaginateResponse(
+                    name=row[0],
+                    user_type=row[1],
+                    is_guest=bool(row[2]),
+                    admin=bool(row[3]),
+                    deactivated=bool(row[4]),
+                    shadow_banned=bool(row[5]),
+                    displayname=row[6],
+                    avatar_url=row[7],
+                    creation_ts=row[8],
+                    approved=bool(row[9]),
+                    erased=bool(row[10]),
+                    last_seen_ts=row[11],
+                    locked=bool(row[12]),
+                )
+                for row in txn
+            ]
 
             return users, count
 
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index d7482a1f4e..07f9b65af3 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -747,8 +747,16 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
         )
 
         # Invalidate the cache for any ignored users which were added or removed.
-        for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
-            self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
+        self._invalidate_cache_and_stream_bulk(
+            txn,
+            self.ignored_by,
+            [
+                (ignored_user_id,)
+                for ignored_user_id in (
+                    previously_ignored_users ^ currently_ignored_users
+                )
+            ],
+        )
         self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))
 
     async def remove_account_data_for_user(
@@ -824,10 +832,14 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
                 )
 
                 # Invalidate the cache for ignored users which were removed.
-                for ignored_user_id in previously_ignored_users:
-                    self._invalidate_cache_and_stream(
-                        txn, self.ignored_by, (ignored_user_id,)
-                    )
+                self._invalidate_cache_and_stream_bulk(
+                    txn,
+                    self.ignored_by,
+                    [
+                        (ignored_user_id,)
+                        for ignored_user_id in previously_ignored_users
+                    ],
+                )
 
                 # Invalidate for this user the cache tracking ignored users.
                 self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))
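
Both call sites above collapse a per-user invalidation loop into one bulk call over a list of one-element key tuples; the first site only touches users whose ignored status actually changed, i.e. the symmetric difference of the old and new sets. A toy version of that key construction (invalidate_bulk_stub is a stand-in for _invalidate_cache_and_stream_bulk):

    from typing import Collection, Tuple

    def invalidate_bulk_stub(key_tuples: Collection[Tuple[str, ...]]) -> int:
        # Stand-in for _invalidate_cache_and_stream_bulk: just count the keys.
        return len(key_tuples)

    previously_ignored = {"@a:example.org", "@b:example.org"}
    currently_ignored = {"@b:example.org", "@c:example.org"}

    # Only @a (removed) and @c (added) changed, so only their entries go stale.
    changed = [(user_id,) for user_id in previously_ignored ^ currently_ignored]
    assert invalidate_bulk_stub(changed) == 2
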
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 4d0470ffd9..d7232f566b 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -483,6 +483,30 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         txn.call_after(cache_func.invalidate, keys)
         self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
 
+    def _invalidate_cache_and_stream_bulk(
+        self,
+        txn: LoggingTransaction,
+        cache_func: CachedFunction,
+        key_tuples: Collection[Tuple[Any, ...]],
+    ) -> None:
+        """A bulk version of _invalidate_cache_and_stream.
+
+        Locally invalidate every key-tuple in `key_tuples`, then emit invalidations
+        for each key-tuple over replication.
+
+        This implementation is more efficient than a loop which repeatedly calls the
+        non-bulk version.
+        """
+        if not key_tuples:
+            return
+
+        for keys in key_tuples:
+            txn.call_after(cache_func.invalidate, keys)
+
+        self._send_invalidation_to_replication_bulk(
+            txn, cache_func.__name__, key_tuples
+        )
+
     def _invalidate_all_cache_and_stream(
         self, txn: LoggingTransaction, cache_func: CachedFunction
     ) -> None:
@@ -564,10 +588,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         if isinstance(self.database_engine, PostgresEngine):
             assert self._cache_id_gen is not None
 
-            # get_next() returns a context manager which is designed to wrap
-            # the transaction. However, we want to only get an ID when we want
-            # to use it, here, so we need to call __enter__ manually, and have
-            # __exit__ called after the transaction finishes.
             stream_id = self._cache_id_gen.get_next_txn(txn)
             txn.call_after(self.hs.get_notifier().on_new_replication_data)
 
@@ -586,6 +606,53 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                 },
             )
 
+    def _send_invalidation_to_replication_bulk(
+        self,
+        txn: LoggingTransaction,
+        cache_name: str,
+        key_tuples: Collection[Tuple[Any, ...]],
+    ) -> None:
+        """Announce the invalidation of multiple (but not all) cache entries.
+
+        This is more efficient than repeated calls to the non-bulk version. It should
+        NOT be used to invalidate the entire cache: use
+        `_send_invalidation_to_replication` with keys=None.
+
+        Note that this does *not* invalidate the cache locally.
+
+        Args:
+            txn
+            cache_name
+            key_tuples: Key-tuples to invalidate. Assumed to be non-empty.
+        """
+        if isinstance(self.database_engine, PostgresEngine):
+            assert self._cache_id_gen is not None
+
+            stream_ids = self._cache_id_gen.get_next_mult_txn(txn, len(key_tuples))
+            ts = self._clock.time_msec()
+            txn.call_after(self.hs.get_notifier().on_new_replication_data)
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="cache_invalidation_stream_by_instance",
+                keys=(
+                    "stream_id",
+                    "instance_name",
+                    "cache_func",
+                    "keys",
+                    "invalidation_ts",
+                ),
+                values=[
+                    # We convert key_tuples to a list here because psycopg2 serialises
+                    # lists as pq arrays, but serialises tuples as "composite types".
+                    # (We need an array because the `keys` column has type `text[]`.)
+                    # See:
+                    #     https://www.psycopg.org/docs/usage.html#adapt-list
+                    #     https://www.psycopg.org/docs/usage.html#adapt-tuple
+                    (stream_id, self._instance_name, cache_name, list(key_tuple), ts)
+                    for stream_id, key_tuple in zip(stream_ids, key_tuples)
+                ],
+            )
+
     def get_cache_stream_token_for_writer(self, instance_name: str) -> int:
         if self._cache_id_gen:
             return self._cache_id_gen.get_current_token_for_writer(instance_name)
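
The bulk replication path above allocates one stream ID per key-tuple and writes them with a single simple_insert_many_txn call; the list() conversion is what makes psycopg2 adapt each key-tuple as a SQL array rather than a composite type. An illustrative construction of those rows (the cache name, instance name and timestamp are made up):

    stream_ids = [101, 102, 103]
    key_tuples = [("@a:example.org",), ("@b:example.org",), ("@c:example.org", "DEVICEID")]
    instance_name = "master"
    invalidation_ts = 1_700_000_000_000

    rows = [
        # (stream_id, instance_name, cache_func, keys, invalidation_ts)
        (stream_id, instance_name, "get_device", list(key_tuple), invalidation_ts)
        for stream_id, key_tuple in zip(stream_ids, key_tuples)
    ]
    assert rows[2][3] == ["@c:example.org", "DEVICEID"]  # a list, not a tuple
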
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 3e7425d4a6..02dddd1da4 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -450,14 +450,12 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         user_id: str,
         device_id: Optional[str],
         up_to_stream_id: int,
-        limit: Optional[int] = None,
     ) -> int:
         """
         Args:
             user_id: The recipient user_id.
             device_id: The recipient device_id.
             up_to_stream_id: Where to delete messages up to.
-            limit: maximum number of messages to delete
 
         Returns:
             The number of messages deleted.
@@ -478,32 +476,22 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 log_kv({"message": "No changes in cache since last check"})
                 return 0
 
-        def delete_messages_for_device_txn(txn: LoggingTransaction) -> int:
-            limit_statement = "" if limit is None else f"LIMIT {limit}"
-            sql = f"""
-                DELETE FROM device_inbox WHERE user_id = ? AND device_id = ? AND stream_id <= (
-                  SELECT MAX(stream_id) FROM (
-                    SELECT stream_id FROM device_inbox
-                    WHERE user_id = ? AND device_id = ? AND stream_id <= ?
-                    ORDER BY stream_id
-                    {limit_statement}
-                  ) AS q1
-                )
-                """
-            txn.execute(sql, (user_id, device_id, user_id, device_id, up_to_stream_id))
-            return txn.rowcount
-
-        count = await self.db_pool.runInteraction(
-            "delete_messages_for_device", delete_messages_for_device_txn
-        )
+        from_stream_id = None
+        count = 0
+        while True:
+            from_stream_id, loop_count = await self.delete_messages_for_device_between(
+                user_id,
+                device_id,
+                from_stream_id=from_stream_id,
+                to_stream_id=up_to_stream_id,
+                limit=1000,
+            )
+            count += loop_count
+            if from_stream_id is None:
+                break
 
         log_kv({"message": f"deleted {count} messages for device", "count": count})
 
-        # In this case we don't know if we hit the limit or the delete is complete
-        # so let's not update the cache.
-        if count == limit:
-            return count
-
         # Update the cache, ensuring that we only ever increase the value
         updated_last_deleted_stream_id = self._last_device_delete_cache.get(
             (user_id, device_id), 0
@@ -515,6 +503,74 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         return count
 
     @trace
+    async def delete_messages_for_device_between(
+        self,
+        user_id: str,
+        device_id: Optional[str],
+        from_stream_id: Optional[int],
+        to_stream_id: int,
+        limit: int,
+    ) -> Tuple[Optional[int], int]:
+        """Delete N device messages between the stream IDs, returning the
+        highest stream ID deleted (or None if all messages in the range have
+        been deleted) and the number of messages deleted.
+
+        This is more efficient than `delete_messages_for_device` when called in
+        a loop to batch delete messages.
+        """
+
+        # Keeping track of a lower bound of stream ID below which we've already
+        # deleted everything makes the queries much faster. Otherwise, every time
+        # we scan for rows to delete we'd re-scan across all the rows that have
+        # previously been deleted (until the next table VACUUM).
+
+        if from_stream_id is None:
+            # Minimum device stream ID is 1.
+            from_stream_id = 0
+
+        def delete_messages_for_device_between_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[Optional[int], int]:
+            txn.execute(
+                """
+                SELECT MAX(stream_id) FROM (
+                    SELECT stream_id FROM device_inbox
+                    WHERE user_id = ? AND device_id = ?
+                        AND ? < stream_id AND stream_id <= ?
+                    ORDER BY stream_id
+                    LIMIT ?
+                ) AS d
+                """,
+                (user_id, device_id, from_stream_id, to_stream_id, limit),
+            )
+            row = txn.fetchone()
+            if row is None or row[0] is None:
+                return None, 0
+
+            (max_stream_id,) = row
+
+            txn.execute(
+                """
+                DELETE FROM device_inbox
+                WHERE user_id = ? AND device_id = ?
+                AND ? < stream_id AND stream_id <= ?
+                """,
+                (user_id, device_id, from_stream_id, max_stream_id),
+            )
+
+            num_deleted = txn.rowcount
+            if num_deleted < limit:
+                return None, num_deleted
+
+            return max_stream_id, num_deleted
+
+        return await self.db_pool.runInteraction(
+            "delete_messages_for_device_between",
+            delete_messages_for_device_between_txn,
+            db_autocommit=True,  # We don't need to run in a transaction
+        )
+
+    @trace
     async def get_new_device_msgs_for_remote(
         self, destination: str, last_stream_id: int, current_stream_id: int, limit: int
     ) -> Tuple[List[JsonDict], int]:
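
delete_messages_for_device now drains the inbox in fixed-size batches, feeding each call's returned lower bound back into the next call and stopping once it comes back as None. A toy in-memory model of that resumable contract (a plain list of stream IDs stands in for the device_inbox table):

    from typing import List, Optional, Tuple

    def delete_between(
        inbox: List[int], from_id: Optional[int], to_id: int, limit: int
    ) -> Tuple[Optional[int], int]:
        # Delete at most `limit` messages with from_id < stream_id <= to_id and
        # return (new lower bound, deleted count); None means the range is drained.
        lower = from_id if from_id is not None else 0
        batch = sorted(s for s in inbox if lower < s <= to_id)[:limit]
        for s in batch:
            inbox.remove(s)
        if len(batch) < limit:
            return None, len(batch)
        return batch[-1], len(batch)

    inbox = list(range(1, 26))  # stream IDs 1..25
    from_id, total = None, 0
    while True:
        from_id, deleted = delete_between(inbox, from_id, to_id=20, limit=7)
        total += deleted
        if from_id is None:
            break
    assert total == 20 and inbox == [21, 22, 23, 24, 25]
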
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 49edbb9e06..775abbac79 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -255,33 +255,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             A dict containing the device information, or `None` if the device does not
             exist.
         """
-        return await self.db_pool.simple_select_one(
-            table="devices",
-            keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
-            retcols=("user_id", "device_id", "display_name"),
-            desc="get_device",
-            allow_none=True,
-        )
-
-    async def get_device_opt(
-        self, user_id: str, device_id: str
-    ) -> Optional[Dict[str, Any]]:
-        """Retrieve a device. Only returns devices that are not marked as
-        hidden.
-
-        Args:
-            user_id: The ID of the user which owns the device
-            device_id: The ID of the device to retrieve
-        Returns:
-            A dict containing the device information, or None if the device does not exist.
-        """
-        return await self.db_pool.simple_select_one(
+        row = await self.db_pool.simple_select_one(
             table="devices",
             keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
             retcols=("user_id", "device_id", "display_name"),
             desc="get_device",
             allow_none=True,
         )
+        if row is None:
+            return None
+        return {"user_id": row[0], "device_id": row[1], "display_name": row[2]}
 
     async def get_devices_by_user(
         self, user_id: str
@@ -703,7 +686,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             key_names=("destination", "user_id"),
             key_values=[(destination, user_id) for user_id, _ in rows],
             value_names=("stream_id",),
-            value_values=((stream_id,) for _, stream_id in rows),
+            value_values=[(stream_id,) for _, stream_id in rows],
         )
 
         # Delete all sent outbound pokes
@@ -1221,9 +1204,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             retcols=["device_id", "device_data"],
             allow_none=True,
         )
-        return (
-            (row["device_id"], json_decoder.decode(row["device_data"])) if row else None
-        )
+        return (row[0], json_decoder.decode(row[1])) if row else None
 
     def _store_dehydrated_device_txn(
         self,
@@ -1620,7 +1601,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
         #
         # For each duplicate, we delete all the existing rows and put one back.
 
-        KEY_COLS = ["stream_id", "destination", "user_id", "device_id"]
         last_row = progress.get(
             "last_row",
             {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""},
@@ -1628,44 +1608,62 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
 
         def _txn(txn: LoggingTransaction) -> int:
             clause, args = make_tuple_comparison_clause(
-                [(x, last_row[x]) for x in KEY_COLS]
+                [
+                    ("stream_id", last_row["stream_id"]),
+                    ("destination", last_row["destination"]),
+                    ("user_id", last_row["user_id"]),
+                    ("device_id", last_row["device_id"]),
+                ]
             )
-            sql = """
+            sql = f"""
                 SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts
                 FROM device_lists_outbound_pokes
-                WHERE %s
-                GROUP BY %s
+                WHERE {clause}
+                GROUP BY stream_id, destination, user_id, device_id
                 HAVING count(*) > 1
-                ORDER BY %s
+                ORDER BY stream_id, destination, user_id, device_id
                 LIMIT ?
-                """ % (
-                clause,  # WHERE
-                ",".join(KEY_COLS),  # GROUP BY
-                ",".join(KEY_COLS),  # ORDER BY
-            )
+                """
             txn.execute(sql, args + [batch_size])
-            rows = self.db_pool.cursor_to_dict(txn)
+            rows = txn.fetchall()
 
-            row = None
-            for row in rows:
+            stream_id, destination, user_id, device_id = None, None, None, None
+            for stream_id, destination, user_id, device_id, _ in rows:
                 self.db_pool.simple_delete_txn(
                     txn,
                     "device_lists_outbound_pokes",
-                    {x: row[x] for x in KEY_COLS},
+                    {
+                        "stream_id": stream_id,
+                        "destination": destination,
+                        "user_id": user_id,
+                        "device_id": device_id,
+                    },
                 )
 
-                row["sent"] = False
                 self.db_pool.simple_insert_txn(
                     txn,
                     "device_lists_outbound_pokes",
-                    row,
+                    {
+                        "stream_id": stream_id,
+                        "destination": destination,
+                        "user_id": user_id,
+                        "device_id": device_id,
+                        "sent": False,
+                    },
                 )
 
-            if row:
+            if rows:
                 self.db_pool.updates._background_update_progress_txn(
                     txn,
                     BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES,
-                    {"last_row": row},
+                    {
+                        "last_row": {
+                            "stream_id": stream_id,
+                            "destination": destination,
+                            "user_id": user_id,
+                            "device_id": device_id,
+                        }
+                    },
                 )
 
             return len(rows)
@@ -2309,13 +2307,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         `FALSE` have not been converted.
         """
 
-        row = await self.db_pool.simple_select_one(
-            table="device_lists_changes_converted_stream_position",
-            keyvalues={},
-            retcols=["stream_id", "room_id"],
-            desc="get_device_change_last_converted_pos",
+        return cast(
+            Tuple[int, str],
+            await self.db_pool.simple_select_one(
+                table="device_lists_changes_converted_stream_position",
+                keyvalues={},
+                retcols=["stream_id", "room_id"],
+                desc="get_device_change_last_converted_pos",
+            ),
         )
-        return row["stream_id"], row["room_id"]
 
     async def set_device_change_last_converted_pos(
         self,
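
Several call sites in this file adapt to simple_select_one now returning a positional tuple (in retcols order) instead of a dict; callers either unpack the tuple or cast it to a concrete Tuple type. A standalone sketch of the adaptation, using a hard-coded row:

    from typing import Optional, Tuple

    # What simple_select_one might hand back for
    # retcols=("user_id", "device_id", "display_name").
    row: Optional[Tuple[str, str, Optional[str]]] = ("@a:example.org", "DEVICEID", None)

    if row is None:
        device = None
    else:
        user_id, device_id, display_name = row  # instead of row["user_id"], ...
        device = {"user_id": user_id, "device_id": device_id, "display_name": display_name}

    assert device is not None and device["device_id"] == "DEVICEID"
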
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index ad904a26a6..fae23c3407 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -506,19 +506,26 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
                     # it isn't there.
                     raise StoreError(404, "No backup with that version exists")
 
-            result = self.db_pool.simple_select_one_txn(
-                txn,
-                table="e2e_room_keys_versions",
-                keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
-                retcols=("version", "algorithm", "auth_data", "etag"),
-                allow_none=False,
+            row = cast(
+                Tuple[int, str, str, Optional[int]],
+                self.db_pool.simple_select_one_txn(
+                    txn,
+                    table="e2e_room_keys_versions",
+                    keyvalues={
+                        "user_id": user_id,
+                        "version": this_version,
+                        "deleted": 0,
+                    },
+                    retcols=("version", "algorithm", "auth_data", "etag"),
+                    allow_none=False,
+                ),
             )
-            assert result is not None  # see comment on `simple_select_one_txn`
-            result["auth_data"] = db_to_json(result["auth_data"])
-            result["version"] = str(result["version"])
-            if result["etag"] is None:
-                result["etag"] = 0
-            return result
+            return {
+                "auth_data": db_to_json(row[2]),
+                "version": str(row[0]),
+                "algorithm": row[1],
+                "etag": 0 if row[3] is None else row[3],
+            }
 
         return await self.db_pool.runInteraction(
             "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 4f96ac25c7..9e98729330 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -1237,13 +1237,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         for user_id, device_id, algorithm, key_id, key_json in claimed_keys:
             device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
             device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
-
-            if (user_id, device_id) in seen_user_device:
-                continue
             seen_user_device.add((user_id, device_id))
-            self._invalidate_cache_and_stream(
-                txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
-            )
+
+        self._invalidate_cache_and_stream_bulk(
+            txn, self.get_e2e_unused_fallback_key_types, seen_user_device
+        )
 
         return results
 
@@ -1268,9 +1266,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             if row is None:
                 continue
 
-            key_id = row["key_id"]
-            key_json = row["key_json"]
-            used = row["used"]
+            key_id, key_json, used = row
 
             # Mark fallback key as used if not already.
             if not used and mark_as_used:
@@ -1376,17 +1372,62 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             List[Tuple[str, str, str, str, str]], txn.execute_values(sql, query_list)
         )
 
-        seen_user_device: Set[Tuple[str, str]] = set()
-        for user_id, device_id, _, _, _ in otk_rows:
-            if (user_id, device_id) in seen_user_device:
-                continue
-            seen_user_device.add((user_id, device_id))
-            self._invalidate_cache_and_stream(
-                txn, self.count_e2e_one_time_keys, (user_id, device_id)
-            )
+        seen_user_device = {
+            (user_id, device_id) for user_id, device_id, _, _, _ in otk_rows
+        }
+        self._invalidate_cache_and_stream_bulk(
+            txn,
+            self.count_e2e_one_time_keys,
+            seen_user_device,
+        )
 
         return otk_rows
 
+    async def get_master_cross_signing_key_updatable_before(
+        self, user_id: str
+    ) -> Tuple[bool, Optional[int]]:
+        """Get time before which a master cross-signing key may be replaced without UIA.
+
+        (UIA means "User-Interactive Auth".)
+
+        There are three cases to distinguish:
+         (1) No master cross-signing key.
+         (2) The key exists, but there is no replace-without-UIA timestamp in the DB.
+         (3) The key exists, and has such a timestamp recorded.
+
+        Returns: a 2-tuple of:
+          - a boolean: is there a master cross-signing key already?
+          - an optional timestamp, directly taken from the DB.
+
+        In terms of the cases above, these are:
+         (1) (False, None).
+         (2) (True, None).
+         (3) (True, <timestamp in ms>).
+
+        """
+
+        def impl(txn: LoggingTransaction) -> Tuple[bool, Optional[int]]:
+            # We want to distinguish between three cases:
+            txn.execute(
+                """
+                SELECT updatable_without_uia_before_ms
+                FROM e2e_cross_signing_keys
+                WHERE user_id = ? AND keytype = 'master'
+                ORDER BY stream_id DESC
+                LIMIT 1
+            """,
+                (user_id,),
+            )
+            row = cast(Optional[Tuple[Optional[int]]], txn.fetchone())
+            if row is None:
+                return False, None
+            return True, row[0]
+
+        return await self.db_pool.runInteraction(
+            "e2e_cross_signing_keys",
+            impl,
+        )
+
 
 class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
     def __init__(
@@ -1634,3 +1675,42 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             ],
             desc="add_e2e_signing_key",
         )
+
+    async def allow_master_cross_signing_key_replacement_without_uia(
+        self, user_id: str, duration_ms: int
+    ) -> Optional[int]:
+        """Mark this user's latest master key as being replaceable without UIA.
+
+        Said replacement will only be permitted for a short time after calling this
+        function. That time period is controlled by the duration argument.
+
+        Returns:
+            None, if there is no such key.
+            Otherwise, the timestamp before which replacement is allowed without UIA.
+        """
+        timestamp = self._clock.time_msec() + duration_ms
+
+        def impl(txn: LoggingTransaction) -> Optional[int]:
+            txn.execute(
+                """
+                UPDATE e2e_cross_signing_keys
+                SET updatable_without_uia_before_ms = ?
+                WHERE stream_id = (
+                    SELECT stream_id
+                    FROM e2e_cross_signing_keys
+                    WHERE user_id = ? AND keytype = 'master'
+                    ORDER BY stream_id DESC
+                    LIMIT 1
+                )
+            """,
+                (timestamp, user_id),
+            )
+            if txn.rowcount == 0:
+                return None
+
+            return timestamp
+
+        return await self.db_pool.runInteraction(
+            "allow_master_cross_signing_key_replacement_without_uia",
+            impl,
+        )
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index f1b0991503..7e992ca4a2 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -193,7 +193,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         # Check if we have indexed the room so we can use the chain cover
         # algorithm.
         room = await self.get_room(room_id)  # type: ignore[attr-defined]
-        if room["has_auth_chain_index"]:
+        # If the room has an auth chain index.
+        if room[1]:
             try:
                 return await self.db_pool.runInteraction(
                     "get_auth_chain_ids_chains",
@@ -411,7 +412,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         # Check if we have indexed the room so we can use the chain cover
         # algorithm.
         room = await self.get_room(room_id)  # type: ignore[attr-defined]
-        if room["has_auth_chain_index"]:
+        # If the room has an auth chain index.
+        if room[1]:
             try:
                 return await self.db_pool.runInteraction(
                     "get_auth_chain_difference_chains",
@@ -1437,24 +1439,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             )
 
             if event_lookup_result is not None:
+                event_type, depth, stream_ordering = event_lookup_result
                 logger.debug(
                     "_get_backfill_events(room_id=%s): seed_event_id=%s depth=%s stream_ordering=%s type=%s",
                     room_id,
                     seed_event_id,
-                    event_lookup_result["depth"],
-                    event_lookup_result["stream_ordering"],
-                    event_lookup_result["type"],
+                    depth,
+                    stream_ordering,
+                    event_type,
                 )
 
-                if event_lookup_result["depth"]:
-                    queue.put(
-                        (
-                            -event_lookup_result["depth"],
-                            -event_lookup_result["stream_ordering"],
-                            seed_event_id,
-                            event_lookup_result["type"],
-                        )
-                    )
+                if depth:
+                    queue.put((-depth, -stream_ordering, seed_event_id, event_type))
 
         while not queue.empty() and len(event_id_results) < limit:
             try:
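
The backfill queue above stores (-depth, -stream_ordering, ...) tuples so that Python's min-heap PriorityQueue pops the deepest, most recently ordered events first. A tiny demonstration of that ordering trick:

    from queue import PriorityQueue

    backfill_queue: PriorityQueue = PriorityQueue()
    for depth, stream_ordering, event_id in [(5, 10, "$a"), (7, 3, "$b"), (7, 9, "$c")]:
        backfill_queue.put((-depth, -stream_ordering, event_id))

    # Deepest first; ties broken by the higher stream ordering.
    assert [backfill_queue.get()[2] for _ in range(3)] == ["$c", "$b", "$a"]
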
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 3c1492e3ad..5207cc0f4e 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -79,7 +79,7 @@ class DeltaState:
     Attributes:
         to_delete: List of type/state_keys to delete from current state
         to_insert: Map of state to upsert into current state
-        no_longer_in_room: The server is not longer in the room, so the room
+        no_longer_in_room: The server is no longer in the room, so the room
             should e.g. be removed from `current_state_events` table.
     """
 
@@ -131,22 +131,25 @@ class PersistEventsStore:
     @trace
     async def _persist_events_and_state_updates(
         self,
+        room_id: str,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         *,
-        state_delta_for_room: Dict[str, DeltaState],
-        new_forward_extremities: Dict[str, Set[str]],
+        state_delta_for_room: Optional[DeltaState],
+        new_forward_extremities: Optional[Set[str]],
         use_negative_stream_ordering: bool = False,
         inhibit_local_membership_updates: bool = False,
     ) -> None:
         """Persist a set of events alongside updates to the current state and
-        forward extremities tables.
+                forward extremities tables.
+
+        Assumes that we are only persisting events for one room at a time.
 
         Args:
+            room_id:
             events_and_contexts:
-            state_delta_for_room: Map from room_id to the delta to apply to
-                room state
-            new_forward_extremities: Map from room_id to set of event IDs
-                that are the new forward extremities of the room.
+            state_delta_for_room: The delta to apply to the room state
+            new_forward_extremities: A set of event IDs that are the new forward
+                extremities of the room.
             use_negative_stream_ordering: Whether to start stream_ordering on
                 the negative side and decrement. This should be set as True
                 for backfilled events because backfilled events get a negative
@@ -196,6 +199,7 @@ class PersistEventsStore:
             await self.db_pool.runInteraction(
                 "persist_events",
                 self._persist_events_txn,
+                room_id=room_id,
                 events_and_contexts=events_and_contexts,
                 inhibit_local_membership_updates=inhibit_local_membership_updates,
                 state_delta_for_room=state_delta_for_room,
@@ -221,9 +225,9 @@ class PersistEventsStore:
 
                 event_counter.labels(event.type, origin_type, origin_entity).inc()
 
-            for room_id, latest_event_ids in new_forward_extremities.items():
+            if new_forward_extremities:
                 self.store.get_latest_event_ids_in_room.prefill(
-                    (room_id,), frozenset(latest_event_ids)
+                    (room_id,), frozenset(new_forward_extremities)
                 )
 
     async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
@@ -336,10 +340,11 @@ class PersistEventsStore:
         self,
         txn: LoggingTransaction,
         *,
+        room_id: str,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         inhibit_local_membership_updates: bool,
-        state_delta_for_room: Dict[str, DeltaState],
-        new_forward_extremities: Dict[str, Set[str]],
+        state_delta_for_room: Optional[DeltaState],
+        new_forward_extremities: Optional[Set[str]],
     ) -> None:
         """Insert some number of room events into the necessary database tables.
 
@@ -347,8 +352,11 @@ class PersistEventsStore:
         and the rejections table. Things reading from those tables will need to check
         whether the event was rejected.
 
+        Assumes that we are only persisting events for one room at a time.
+
         Args:
             txn
+            room_id: The room the events are from
             events_and_contexts: events to persist
             inhibit_local_membership_updates: Stop the local_current_membership
                 from being updated by these events. This should be set to True
@@ -357,10 +365,9 @@ class PersistEventsStore:
             delete_existing True to purge existing table rows for the events
                 from the database. This is useful when retrying due to
                 IntegrityError.
-            state_delta_for_room: The current-state delta for each room.
-            new_forward_extremities: The new forward extremities for each room.
-                For each room, a list of the event ids which are the forward
-                extremities.
+            state_delta_for_room: The current-state delta for the room.
+            new_forward_extremities: The new forward extremities for the room:
+                a set of the event ids which are the forward extremities.
 
         Raises:
             PartialStateConflictError: if attempting to persist a partial state event in
@@ -376,14 +383,13 @@ class PersistEventsStore:
         #
         # Annoyingly SQLite doesn't support row level locking.
         if isinstance(self.database_engine, PostgresEngine):
-            for room_id in {e.room_id for e, _ in events_and_contexts}:
-                txn.execute(
-                    "SELECT room_version FROM rooms WHERE room_id = ? FOR SHARE",
-                    (room_id,),
-                )
-                row = txn.fetchone()
-                if row is None:
-                    raise Exception(f"Room does not exist {room_id}")
+            txn.execute(
+                "SELECT room_version FROM rooms WHERE room_id = ? FOR SHARE",
+                (room_id,),
+            )
+            row = txn.fetchone()
+            if row is None:
+                raise Exception(f"Room does not exist {room_id}")
 
         # stream orderings should have been assigned by now
         assert min_stream_order
@@ -419,7 +425,9 @@ class PersistEventsStore:
             events_and_contexts
         )
 
-        self._update_room_depths_txn(txn, events_and_contexts=events_and_contexts)
+        self._update_room_depths_txn(
+            txn, room_id, events_and_contexts=events_and_contexts
+        )
 
         # _update_outliers_txn filters out any events which have already been
         # persisted, and returns the filtered list.
@@ -432,11 +440,13 @@ class PersistEventsStore:
 
         self._store_event_txn(txn, events_and_contexts=events_and_contexts)
 
-        self._update_forward_extremities_txn(
-            txn,
-            new_forward_extremities=new_forward_extremities,
-            max_stream_order=max_stream_order,
-        )
+        if new_forward_extremities:
+            self._update_forward_extremities_txn(
+                txn,
+                room_id,
+                new_forward_extremities=new_forward_extremities,
+                max_stream_order=max_stream_order,
+            )
 
         self._persist_transaction_ids_txn(txn, events_and_contexts)
 
@@ -464,7 +474,10 @@ class PersistEventsStore:
         # We call this last as it assumes we've inserted the events into
         # room_memberships, where applicable.
         # NB: This function invalidates all state related caches
-        self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
+        if state_delta_for_room:
+            self._update_current_state_txn(
+                txn, room_id, state_delta_for_room, min_stream_order
+            )
 
     def _persist_event_auth_chain_txn(
         self,
@@ -1026,74 +1039,75 @@ class PersistEventsStore:
             await self.db_pool.runInteraction(
                 "update_current_state",
                 self._update_current_state_txn,
-                state_delta_by_room={room_id: state_delta},
+                room_id,
+                delta_state=state_delta,
                 stream_id=stream_ordering,
             )
 
     def _update_current_state_txn(
         self,
         txn: LoggingTransaction,
-        state_delta_by_room: Dict[str, DeltaState],
+        room_id: str,
+        delta_state: DeltaState,
         stream_id: int,
     ) -> None:
-        for room_id, delta_state in state_delta_by_room.items():
-            to_delete = delta_state.to_delete
-            to_insert = delta_state.to_insert
-
-            # Figure out the changes of membership to invalidate the
-            # `get_rooms_for_user` cache.
-            # We find out which membership events we may have deleted
-            # and which we have added, then we invalidate the caches for all
-            # those users.
-            members_changed = {
-                state_key
-                for ev_type, state_key in itertools.chain(to_delete, to_insert)
-                if ev_type == EventTypes.Member
-            }
+        to_delete = delta_state.to_delete
+        to_insert = delta_state.to_insert
+
+        # Figure out the changes of membership to invalidate the
+        # `get_rooms_for_user` cache.
+        # We find out which membership events we may have deleted
+        # and which we have added, then we invalidate the caches for all
+        # those users.
+        members_changed = {
+            state_key
+            for ev_type, state_key in itertools.chain(to_delete, to_insert)
+            if ev_type == EventTypes.Member
+        }
 
-            if delta_state.no_longer_in_room:
-                # Server is no longer in the room so we delete the room from
-                # current_state_events, being careful we've already updated the
-                # rooms.room_version column (which gets populated in a
-                # background task).
-                self._upsert_room_version_txn(txn, room_id)
+        if delta_state.no_longer_in_room:
+            # Server is no longer in the room so we delete the room from
+            # current_state_events, being careful we've already updated the
+            # rooms.room_version column (which gets populated in a
+            # background task).
+            self._upsert_room_version_txn(txn, room_id)
 
-                # Before deleting we populate the current_state_delta_stream
-                # so that async background tasks get told what happened.
-                sql = """
+            # Before deleting we populate the current_state_delta_stream
+            # so that async background tasks get told what happened.
+            sql = """
                     INSERT INTO current_state_delta_stream
                         (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id)
                     SELECT ?, ?, room_id, type, state_key, null, event_id
                         FROM current_state_events
                         WHERE room_id = ?
                 """
-                txn.execute(sql, (stream_id, self._instance_name, room_id))
+            txn.execute(sql, (stream_id, self._instance_name, room_id))
 
-                # We also want to invalidate the membership caches for users
-                # that were in the room.
-                users_in_room = self.store.get_users_in_room_txn(txn, room_id)
-                members_changed.update(users_in_room)
+            # We also want to invalidate the membership caches for users
+            # that were in the room.
+            users_in_room = self.store.get_users_in_room_txn(txn, room_id)
+            members_changed.update(users_in_room)
 
-                self.db_pool.simple_delete_txn(
-                    txn,
-                    table="current_state_events",
-                    keyvalues={"room_id": room_id},
-                )
-            else:
-                # We're still in the room, so we update the current state as normal.
+            self.db_pool.simple_delete_txn(
+                txn,
+                table="current_state_events",
+                keyvalues={"room_id": room_id},
+            )
+        else:
+            # We're still in the room, so we update the current state as normal.
 
-                # First we add entries to the current_state_delta_stream. We
-                # do this before updating the current_state_events table so
-                # that we can use it to calculate the `prev_event_id`. (This
-                # allows us to not have to pull out the existing state
-                # unnecessarily).
-                #
-                # The stream_id for the update is chosen to be the minimum of the stream_ids
-                # for the batch of the events that we are persisting; that means we do not
-                # end up in a situation where workers see events before the
-                # current_state_delta updates.
-                #
-                sql = """
+            # First we add entries to the current_state_delta_stream. We
+            # do this before updating the current_state_events table so
+            # that we can use it to calculate the `prev_event_id`. (This
+            # allows us to not have to pull out the existing state
+            # unnecessarily).
+            #
+            # The stream_id for the update is chosen to be the minimum of the stream_ids
+            # for the batch of the events that we are persisting; that means we do not
+            # end up in a situation where workers see events before the
+            # current_state_delta updates.
+            #
+            sql = """
                     INSERT INTO current_state_delta_stream
                     (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id)
                     SELECT ?, ?, ?, ?, ?, ?, (
@@ -1101,39 +1115,39 @@ class PersistEventsStore:
                         WHERE room_id = ? AND type = ? AND state_key = ?
                     )
                 """
-                txn.execute_batch(
-                    sql,
+            txn.execute_batch(
+                sql,
+                (
                     (
-                        (
-                            stream_id,
-                            self._instance_name,
-                            room_id,
-                            etype,
-                            state_key,
-                            to_insert.get((etype, state_key)),
-                            room_id,
-                            etype,
-                            state_key,
-                        )
-                        for etype, state_key in itertools.chain(to_delete, to_insert)
-                    ),
-                )
-                # Now we actually update the current_state_events table
+                        stream_id,
+                        self._instance_name,
+                        room_id,
+                        etype,
+                        state_key,
+                        to_insert.get((etype, state_key)),
+                        room_id,
+                        etype,
+                        state_key,
+                    )
+                    for etype, state_key in itertools.chain(to_delete, to_insert)
+                ),
+            )
+            # Now we actually update the current_state_events table
 
-                txn.execute_batch(
-                    "DELETE FROM current_state_events"
-                    " WHERE room_id = ? AND type = ? AND state_key = ?",
-                    (
-                        (room_id, etype, state_key)
-                        for etype, state_key in itertools.chain(to_delete, to_insert)
-                    ),
-                )
+            txn.execute_batch(
+                "DELETE FROM current_state_events"
+                " WHERE room_id = ? AND type = ? AND state_key = ?",
+                (
+                    (room_id, etype, state_key)
+                    for etype, state_key in itertools.chain(to_delete, to_insert)
+                ),
+            )
 
-                # We include the membership in the current state table, hence we do
-                # a lookup when we insert. This assumes that all events have already
-                # been inserted into room_memberships.
-                txn.execute_batch(
-                    """INSERT INTO current_state_events
+            # We include the membership in the current state table, hence we do
+            # a lookup when we insert. This assumes that all events have already
+            # been inserted into room_memberships.
+            txn.execute_batch(
+                """INSERT INTO current_state_events
                         (room_id, type, state_key, event_id, membership, event_stream_ordering)
                     VALUES (
                         ?, ?, ?, ?,
@@ -1141,34 +1155,34 @@ class PersistEventsStore:
                         (SELECT stream_ordering FROM events WHERE event_id = ?)
                     )
                     """,
-                    [
-                        (room_id, key[0], key[1], ev_id, ev_id, ev_id)
-                        for key, ev_id in to_insert.items()
-                    ],
-                )
+                [
+                    (room_id, key[0], key[1], ev_id, ev_id, ev_id)
+                    for key, ev_id in to_insert.items()
+                ],
+            )
 
-            # We now update `local_current_membership`. We do this regardless
-            # of whether we're still in the room or not to handle the case where
-            # e.g. we just got banned (where we need to record that fact here).
-
-            # Note: Do we really want to delete rows here (that we do not
-            # subsequently reinsert below)? While technically correct it means
-            # we have no record of the fact the user *was* a member of the
-            # room but got, say, state reset out of it.
-            if to_delete or to_insert:
-                txn.execute_batch(
-                    "DELETE FROM local_current_membership"
-                    " WHERE room_id = ? AND user_id = ?",
-                    (
-                        (room_id, state_key)
-                        for etype, state_key in itertools.chain(to_delete, to_insert)
-                        if etype == EventTypes.Member and self.is_mine_id(state_key)
-                    ),
-                )
+        # We now update `local_current_membership`. We do this regardless
+        # of whether we're still in the room or not to handle the case where
+        # e.g. we just got banned (where we need to record that fact here).
+
+        # Note: Do we really want to delete rows here (that we do not
+        # subsequently reinsert below)? While technically correct it means
+        # we have no record of the fact the user *was* a member of the
+        # room but got, say, state reset out of it.
+        if to_delete or to_insert:
+            txn.execute_batch(
+                "DELETE FROM local_current_membership"
+                " WHERE room_id = ? AND user_id = ?",
+                (
+                    (room_id, state_key)
+                    for etype, state_key in itertools.chain(to_delete, to_insert)
+                    if etype == EventTypes.Member and self.is_mine_id(state_key)
+                ),
+            )
 
-            if to_insert:
-                txn.execute_batch(
-                    """INSERT INTO local_current_membership
+        if to_insert:
+            txn.execute_batch(
+                """INSERT INTO local_current_membership
                         (room_id, user_id, event_id, membership, event_stream_ordering)
                     VALUES (
                         ?, ?, ?,
@@ -1176,29 +1190,27 @@ class PersistEventsStore:
                         (SELECT stream_ordering FROM events WHERE event_id = ?)
                     )
                     """,
-                    [
-                        (room_id, key[1], ev_id, ev_id, ev_id)
-                        for key, ev_id in to_insert.items()
-                        if key[0] == EventTypes.Member and self.is_mine_id(key[1])
-                    ],
-                )
-
-            txn.call_after(
-                self.store._curr_state_delta_stream_cache.entity_has_changed,
-                room_id,
-                stream_id,
+                [
+                    (room_id, key[1], ev_id, ev_id, ev_id)
+                    for key, ev_id in to_insert.items()
+                    if key[0] == EventTypes.Member and self.is_mine_id(key[1])
+                ],
             )
 
-            # Invalidate the various caches
-            self.store._invalidate_state_caches_and_stream(
-                txn, room_id, members_changed
-            )
+        txn.call_after(
+            self.store._curr_state_delta_stream_cache.entity_has_changed,
+            room_id,
+            stream_id,
+        )
 
-            # Check if any of the remote membership changes requires us to
-            # unsubscribe from their device lists.
-            self.store.handle_potentially_left_users_txn(
-                txn, {m for m in members_changed if not self.hs.is_mine_id(m)}
-            )
+        # Invalidate the various caches
+        self.store._invalidate_state_caches_and_stream(txn, room_id, members_changed)
+
+        # Check if any of the remote membership changes requires us to
+        # unsubscribe from their device lists.
+        self.store.handle_potentially_left_users_txn(
+            txn, {m for m in members_changed if not self.hs.is_mine_id(m)}
+        )
 
     def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str) -> None:
         """Update the room version in the database based off current state
@@ -1232,23 +1244,19 @@ class PersistEventsStore:
     def _update_forward_extremities_txn(
         self,
         txn: LoggingTransaction,
-        new_forward_extremities: Dict[str, Set[str]],
+        room_id: str,
+        new_forward_extremities: Set[str],
         max_stream_order: int,
     ) -> None:
-        for room_id in new_forward_extremities.keys():
-            self.db_pool.simple_delete_txn(
-                txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
-            )
+        self.db_pool.simple_delete_txn(
+            txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
+        )
 
         self.db_pool.simple_insert_many_txn(
             txn,
             table="event_forward_extremities",
             keys=("event_id", "room_id"),
-            values=[
-                (ev_id, room_id)
-                for room_id, new_extrem in new_forward_extremities.items()
-                for ev_id in new_extrem
-            ],
+            values=[(ev_id, room_id) for ev_id in new_forward_extremities],
         )
         # We now insert into stream_ordering_to_exterm a mapping from room_id,
         # new stream_ordering to new forward extremeties in the room.
@@ -1260,8 +1268,7 @@ class PersistEventsStore:
             keys=("room_id", "event_id", "stream_ordering"),
             values=[
                 (room_id, event_id, max_stream_order)
-                for room_id, new_extrem in new_forward_extremities.items()
-                for event_id in new_extrem
+                for event_id in new_forward_extremities
             ],
         )
 
@@ -1298,36 +1305,45 @@ class PersistEventsStore:
     def _update_room_depths_txn(
         self,
         txn: LoggingTransaction,
+        room_id: str,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
     ) -> None:
         """Update min_depth for each room
 
         Args:
             txn: db connection
+            room_id: The room ID
             events_and_contexts: events we are persisting
         """
-        depth_updates: Dict[str, int] = {}
+        stream_ordering: Optional[int] = None
+        depth_update = 0
         for event, context in events_and_contexts:
-            # Then update the `stream_ordering` position to mark the latest
-            # event as the front of the room. This should not be done for
-            # backfilled events because backfilled events have negative
-            # stream_ordering and happened in the past so we know that we don't
-            # need to update the stream_ordering tip/front for the room.
+            # Don't update the stream ordering for backfilled events because
+            # backfilled events have negative stream_ordering and happened in the
+            # past, so we know that we don't need to update the stream_ordering
+            # tip/front for the room.
             assert event.internal_metadata.stream_ordering is not None
             if event.internal_metadata.stream_ordering >= 0:
-                txn.call_after(
-                    self.store._events_stream_cache.entity_has_changed,
-                    event.room_id,
-                    event.internal_metadata.stream_ordering,
-                )
+                if stream_ordering is None:
+                    stream_ordering = event.internal_metadata.stream_ordering
+                else:
+                    stream_ordering = max(
+                        stream_ordering, event.internal_metadata.stream_ordering
+                    )
 
             if not event.internal_metadata.is_outlier() and not context.rejected:
-                depth_updates[event.room_id] = max(
-                    event.depth, depth_updates.get(event.room_id, event.depth)
-                )
+                depth_update = max(event.depth, depth_update)
 
-        for room_id, depth in depth_updates.items():
-            self._update_min_depth_for_room_txn(txn, room_id, depth)
+        # Then update the `stream_ordering` position to mark the latest event as
+        # the front of the room.
+        if stream_ordering is not None:
+            txn.call_after(
+                self.store._events_stream_cache.entity_has_changed,
+                room_id,
+                stream_ordering,
+            )
+
+        self._update_min_depth_for_room_txn(txn, room_id, depth_update)
 
     def _update_outliers_txn(
         self,
@@ -1350,13 +1366,19 @@ class PersistEventsStore:
             PartialStateConflictError: if attempting to persist a partial state event in
                 a room that has been un-partial stated.
         """
-        txn.execute(
-            "SELECT event_id, outlier FROM events WHERE event_id in (%s)"
-            % (",".join(["?"] * len(events_and_contexts)),),
-            [event.event_id for event, _ in events_and_contexts],
+        rows = cast(
+            List[Tuple[str, bool]],
+            self.db_pool.simple_select_many_txn(
+                txn,
+                "events",
+                "event_id",
+                [event.event_id for event, _ in events_and_contexts],
+                keyvalues={},
+                retcols=("event_id", "outlier"),
+            ),
         )
 
-        have_persisted = dict(cast(Iterable[Tuple[str, bool]], txn))
+        have_persisted = dict(rows)
 
         logger.debug(
             "_update_outliers_txn: events=%s have_persisted=%s",
@@ -1454,7 +1476,7 @@ class PersistEventsStore:
             txn,
             table="event_json",
             keys=("event_id", "room_id", "internal_metadata", "json", "format_version"),
-            values=(
+            values=[
                 (
                     event.event_id,
                     event.room_id,
@@ -1463,7 +1485,7 @@ class PersistEventsStore:
                     event.format_version,
                 )
                 for event, _ in events_and_contexts
-            ),
+            ],
         )
 
         self.db_pool.simple_insert_many_txn(
@@ -1486,7 +1508,7 @@ class PersistEventsStore:
                 "state_key",
                 "rejection_reason",
             ),
-            values=(
+            values=[
                 (
                     self._instance_name,
                     event.internal_metadata.stream_ordering,
@@ -1505,7 +1527,7 @@ class PersistEventsStore:
                     context.rejected,
                 )
                 for event, context in events_and_contexts
-            ),
+            ],
         )
 
         # If we're persisting an unredacted event we go and ensure
@@ -1528,11 +1550,11 @@ class PersistEventsStore:
             txn,
             table="state_events",
             keys=("event_id", "room_id", "type", "state_key"),
-            values=(
+            values=[
                 (event.event_id, event.room_id, event.type, event.state_key)
                 for event, _ in events_and_contexts
                 if event.is_state()
-            ),
+            ],
         )
 
     def _store_rejected_events_txn(
@@ -1912,8 +1934,7 @@ class PersistEventsStore:
         if row is None:
             return
 
-        redacted_relates_to = row["relates_to_id"]
-        rel_type = row["relation_type"]
+        redacted_relates_to, rel_type = row
         self.db_pool.simple_delete_txn(
             txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
         )
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 0061805150..0c91f19c8e 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -425,7 +425,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         """Background update to clean out extremities that should have been
         deleted previously.
 
-        Mainly used to deal with the aftermath of #5269.
+        Mainly used to deal with the aftermath of https://github.com/matrix-org/synapse/issues/5269.
         """
 
         # This works by first copying all existing forward extremities into the
@@ -558,7 +558,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             )
 
             logger.info(
-                "Deleted %d forward extremities of %d checked, to clean up #5269",
+                "Deleted %d forward extremities of %d checked, to clean up matrix-org/synapse#5269",
                 deleted,
                 len(original_set),
             )
@@ -1222,14 +1222,13 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
                 )
 
                 # Iterate the parent IDs and invalidate caches.
-                for parent_id in {r[1] for r in relations_to_insert}:
-                    cache_tuple = (parent_id,)
-                    self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
-                        txn, self.get_relations_for_event, cache_tuple  # type: ignore[attr-defined]
-                    )
-                    self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
-                        txn, self.get_thread_summary, cache_tuple  # type: ignore[attr-defined]
-                    )
+                cache_tuples = {(r[1],) for r in relations_to_insert}
+                self._invalidate_cache_and_stream_bulk(  # type: ignore[attr-defined]
+                    txn, self.get_relations_for_event, cache_tuples  # type: ignore[attr-defined]
+                )
+                self._invalidate_cache_and_stream_bulk(  # type: ignore[attr-defined]
+                    txn, self.get_thread_summary, cache_tuples  # type: ignore[attr-defined]
+                )
 
             if results:
                 latest_event_id = results[-1][0]
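The background-update hunk above swaps a per-key invalidation loop for one bulk call. The bulk helper's implementation is not part of this diff; the sketch below is only a plausible shape for it, under the assumption that the benefit is notifying other workers once per batch instead of once per key. All names here are illustrative, not Synapse's API.

    from typing import Any, Callable, Collection, Tuple

    CacheKey = Tuple[Any, ...]

    def invalidate_cache_and_stream_bulk_sketch(
        invalidate_local: Callable[[CacheKey], None],
        stream_invalidations: Callable[[Collection[CacheKey]], None],
        key_tuples: Collection[CacheKey],
    ) -> None:
        # Drop every key from the local in-memory cache...
        for key in key_tuples:
            invalidate_local(key)
        # ...then tell other workers about the whole batch in one go, rather
        # than emitting one invalidation per key as the old loop did.
        stream_invalidations(key_tuples)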
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 5bf864c1fb..4125059061 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1312,7 +1312,8 @@ class EventsWorkerStore(SQLBaseStore):
             room_version: Optional[RoomVersion]
             if not room_version_id:
                 # this should only happen for out-of-band membership events which
-                # arrived before #6983 landed. For all other events, we should have
+                # arrived before https://github.com/matrix-org/synapse/issues/6983
+                # landed. For all other events, we should have
                 # an entry in the 'rooms' table.
                 #
                 # However, the 'out_of_band_membership' flag is unreliable for older
@@ -1323,7 +1324,8 @@ class EventsWorkerStore(SQLBaseStore):
                         "Room %s for event %s is unknown" % (d["room_id"], event_id)
                     )
 
-                # so, assuming this is an out-of-band-invite that arrived before #6983
+                # so, assuming this is an out-of-band-invite that arrived before
+                # https://github.com/matrix-org/synapse/issues/6983
                 # landed, we know that the room version must be v5 or earlier (because
                 # v6 hadn't been invented at that point, so invites from such rooms
                 # would have been rejected.)
@@ -1998,7 +2000,7 @@ class EventsWorkerStore(SQLBaseStore):
         if not res:
             raise SynapseError(404, "Could not find event %s" % (event_id,))
 
-        return int(res["topological_ordering"]), int(res["stream_ordering"])
+        return int(res[0]), int(res[1])
 
     async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]:
         """Retrieve the entry with the lowest expiry timestamp in the event_expiry
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index ce88772f9e..c700872fdc 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -107,13 +107,16 @@ class KeyStore(CacheInvalidationWorkerStore):
             # invalidate takes a tuple corresponding to the params of
             # _get_server_keys_json. _get_server_keys_json only takes one
             # param, which is itself the 2-tuple (server_name, key_id).
-            for key_id in verify_keys:
-                self._invalidate_cache_and_stream(
-                    txn, self._get_server_keys_json, ((server_name, key_id),)
-                )
-                self._invalidate_cache_and_stream(
-                    txn, self.get_server_key_json_for_remote, (server_name, key_id)
-                )
+            self._invalidate_cache_and_stream_bulk(
+                txn,
+                self._get_server_keys_json,
+                [((server_name, key_id),) for key_id in verify_keys],
+            )
+            self._invalidate_cache_and_stream_bulk(
+                txn,
+                self.get_server_key_json_for_remote,
+                [(server_name, key_id) for key_id in verify_keys],
+            )
 
         await self.db_pool.runInteraction(
             "store_server_keys_response", store_server_keys_response_txn
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index aeb3db596c..149135b8b5 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -15,9 +15,7 @@
 from enum import Enum
 from typing import (
     TYPE_CHECKING,
-    Any,
     Collection,
-    Dict,
     Iterable,
     List,
     Optional,
@@ -26,6 +24,8 @@ from typing import (
     cast,
 )
 
+import attr
+
 from synapse.api.constants import Direction
 from synapse.logging.opentracing import trace
 from synapse.media._base import ThumbnailInfo
@@ -45,6 +45,40 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = (
 )
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class LocalMedia:
+    media_id: str
+    media_type: str
+    media_length: Optional[int]
+    upload_name: str
+    created_ts: int
+    url_cache: Optional[str]
+    last_access_ts: int
+    quarantined_by: Optional[str]
+    safe_from_quarantine: bool
+    user_id: Optional[str]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RemoteMedia:
+    media_origin: str
+    media_id: str
+    media_type: str
+    media_length: int
+    upload_name: Optional[str]
+    filesystem_id: str
+    created_ts: int
+    last_access_ts: int
+    quarantined_by: Optional[str]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class UrlCache:
+    response_code: int
+    expires_ts: int
+    og: Union[str, bytes]
+
+
 class MediaSortOrder(Enum):
     """
     Enum to define the sorting method used when returning media with
@@ -116,6 +150,13 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
             self._drop_media_index_without_method,
         )
 
+        if hs.config.media.can_load_media_repo:
+            self.unused_expiration_time: Optional[
+                int
+            ] = hs.config.media.unused_expiration_time
+        else:
+            self.unused_expiration_time = None
+
     async def _drop_media_index_without_method(
         self, progress: JsonDict, batch_size: int
     ) -> int:
@@ -151,13 +192,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         super().__init__(database, db_conn, hs)
         self.server_name: str = hs.hostname
 
-    async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
+    async def get_local_media(self, media_id: str) -> Optional[LocalMedia]:
         """Get the metadata for a local piece of media
 
         Returns:
             None if the media_id doesn't exist.
         """
-        return await self.db_pool.simple_select_one(
+        row = await self.db_pool.simple_select_one(
             "local_media_repository",
             {"media_id": media_id},
             (
@@ -167,11 +208,27 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 "created_ts",
                 "quarantined_by",
                 "url_cache",
+                "last_access_ts",
                 "safe_from_quarantine",
+                "user_id",
             ),
             allow_none=True,
             desc="get_local_media",
         )
+        if row is None:
+            return None
+        return LocalMedia(
+            media_id=media_id,
+            media_type=row[0],
+            media_length=row[1],
+            upload_name=row[2],
+            created_ts=row[3],
+            quarantined_by=row[4],
+            url_cache=row[5],
+            last_access_ts=row[6],
+            safe_from_quarantine=row[7],
+            user_id=row[8],
+        )
 
     async def get_local_media_by_user_paginate(
         self,
@@ -180,7 +237,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         user_id: str,
         order_by: str = MediaSortOrder.CREATED_TS.value,
         direction: Direction = Direction.FORWARDS,
-    ) -> Tuple[List[Dict[str, Any]], int]:
+    ) -> Tuple[List[LocalMedia], int]:
         """Get a paginated list of metadata for a local piece of media
         which a user_id has uploaded
 
@@ -197,7 +254,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
         def get_local_media_by_user_paginate_txn(
             txn: LoggingTransaction,
-        ) -> Tuple[List[Dict[str, Any]], int]:
+        ) -> Tuple[List[LocalMedia], int]:
             # Set ordering
             order_by_column = MediaSortOrder(order_by).value
 
@@ -217,14 +274,16 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
             sql = """
                 SELECT
-                    "media_id",
-                    "media_type",
-                    "media_length",
-                    "upload_name",
-                    "created_ts",
-                    "last_access_ts",
-                    "quarantined_by",
-                    "safe_from_quarantine"
+                    media_id,
+                    media_type,
+                    media_length,
+                    upload_name,
+                    created_ts,
+                    url_cache,
+                    last_access_ts,
+                    quarantined_by,
+                    safe_from_quarantine,
+                    user_id
                 FROM local_media_repository
                 WHERE user_id = ?
                 ORDER BY {order_by_column} {order}, media_id ASC
@@ -236,7 +295,21 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
             args += [limit, start]
             txn.execute(sql, args)
-            media = self.db_pool.cursor_to_dict(txn)
+            media = [
+                LocalMedia(
+                    media_id=row[0],
+                    media_type=row[1],
+                    media_length=row[2],
+                    upload_name=row[3],
+                    created_ts=row[4],
+                    url_cache=row[5],
+                    last_access_ts=row[6],
+                    quarantined_by=row[7],
+                    safe_from_quarantine=bool(row[8]),
+                    user_id=row[9],
+                )
+                for row in txn
+            ]
             return media, count
 
         return await self.db_pool.runInteraction(
@@ -331,6 +404,23 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         )
 
     @trace
+    async def store_local_media_id(
+        self,
+        media_id: str,
+        time_now_ms: int,
+        user_id: UserID,
+    ) -> None:
+        await self.db_pool.simple_insert(
+            "local_media_repository",
+            {
+                "media_id": media_id,
+                "created_ts": time_now_ms,
+                "user_id": user_id.to_string(),
+            },
+            desc="store_local_media_id",
+        )
+
+    @trace
     async def store_local_media(
         self,
         media_id: str,
@@ -355,6 +445,30 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="store_local_media",
         )
 
+    async def update_local_media(
+        self,
+        media_id: str,
+        media_type: str,
+        upload_name: Optional[str],
+        media_length: int,
+        user_id: UserID,
+        url_cache: Optional[str] = None,
+    ) -> None:
+        await self.db_pool.simple_update_one(
+            "local_media_repository",
+            keyvalues={
+                "user_id": user_id.to_string(),
+                "media_id": media_id,
+            },
+            updatevalues={
+                "media_type": media_type,
+                "upload_name": upload_name,
+                "media_length": media_length,
+                "url_cache": url_cache,
+            },
+            desc="update_local_media",
+        )
+
     async def mark_local_media_as_safe(self, media_id: str, safe: bool = True) -> None:
         """Mark a local media as safe or unsafe from quarantining."""
         await self.db_pool.simple_update_one(
@@ -364,51 +478,72 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="mark_local_media_as_safe",
         )
 
-    async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]:
+    async def count_pending_media(self, user_id: UserID) -> Tuple[int, int]:
+        """Count the number of pending media for a user.
+
+        Returns:
+            A tuple of two integers: the number of pending media uploads and the
+            earliest expiration timestamp.
+        """
+
+        def get_pending_media_txn(txn: LoggingTransaction) -> Tuple[int, int]:
+            sql = """
+            SELECT COUNT(*), MIN(created_ts)
+            FROM local_media_repository
+            WHERE user_id = ?
+                AND created_ts > ?
+                AND media_length IS NULL
+            """
+            assert self.unused_expiration_time is not None
+            txn.execute(
+                sql,
+                (
+                    user_id.to_string(),
+                    self._clock.time_msec() - self.unused_expiration_time,
+                ),
+            )
+            row = txn.fetchone()
+            if not row:
+                return 0, 0
+            return row[0], (row[1] + self.unused_expiration_time if row[1] else 0)
+
+        return await self.db_pool.runInteraction(
+            "get_pending_media", get_pending_media_txn
+        )
+
+    async def get_url_cache(self, url: str, ts: int) -> Optional[UrlCache]:
         """Get the media_id and ts for a cached URL as of the given timestamp
         Returns:
             None if the URL isn't cached.
         """
 
-        def get_url_cache_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
+        def get_url_cache_txn(txn: LoggingTransaction) -> Optional[UrlCache]:
             # get the most recently cached result (relative to the given ts)
-            sql = (
-                "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
-                " FROM local_media_repository_url_cache"
-                " WHERE url = ? AND download_ts <= ?"
-                " ORDER BY download_ts DESC LIMIT 1"
-            )
+            sql = """
+                SELECT response_code, expires_ts, og
+                FROM local_media_repository_url_cache
+                WHERE url = ? AND download_ts <= ?
+                ORDER BY download_ts DESC LIMIT 1
+            """
             txn.execute(sql, (url, ts))
             row = txn.fetchone()
 
             if not row:
                 # ...or if we've requested a timestamp older than the oldest
                 # copy in the cache, return the oldest copy (if any)
-                sql = (
-                    "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
-                    " FROM local_media_repository_url_cache"
-                    " WHERE url = ? AND download_ts > ?"
-                    " ORDER BY download_ts ASC LIMIT 1"
-                )
+                sql = """
+                    SELECT response_code, expires_ts, og
+                    FROM local_media_repository_url_cache
+                    WHERE url = ? AND download_ts > ?
+                    ORDER BY download_ts ASC LIMIT 1
+                """
                 txn.execute(sql, (url, ts))
                 row = txn.fetchone()
 
             if not row:
                 return None
 
-            return dict(
-                zip(
-                    (
-                        "response_code",
-                        "etag",
-                        "expires_ts",
-                        "og",
-                        "media_id",
-                        "download_ts",
-                    ),
-                    row,
-                )
-            )
+            return UrlCache(response_code=row[0], expires_ts=row[1], og=row[2])
 
         return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
 
@@ -418,7 +553,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         response_code: int,
         etag: Optional[str],
         expires_ts: int,
-        og: Optional[str],
+        og: str,
         media_id: str,
         download_ts: int,
     ) -> None:
@@ -484,8 +619,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
     async def get_cached_remote_media(
         self, origin: str, media_id: str
-    ) -> Optional[Dict[str, Any]]:
-        return await self.db_pool.simple_select_one(
+    ) -> Optional[RemoteMedia]:
+        row = await self.db_pool.simple_select_one(
             "remote_media_cache",
             {"media_origin": origin, "media_id": media_id},
             (
@@ -494,11 +629,25 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 "upload_name",
                 "created_ts",
                 "filesystem_id",
+                "last_access_ts",
                 "quarantined_by",
             ),
             allow_none=True,
             desc="get_cached_remote_media",
         )
+        if row is None:
+            return row
+        return RemoteMedia(
+            media_origin=origin,
+            media_id=media_id,
+            media_type=row[0],
+            media_length=row[1],
+            upload_name=row[2],
+            created_ts=row[3],
+            filesystem_id=row[4],
+            last_access_ts=row[5],
+            quarantined_by=row[6],
+        )
 
     async def store_cached_remote_media(
         self,
@@ -597,10 +746,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         t_width: int,
         t_height: int,
         t_type: str,
-    ) -> Optional[Dict[str, Any]]:
+    ) -> Optional[ThumbnailInfo]:
         """Fetch the thumbnail info of given width, height and type."""
 
-        return await self.db_pool.simple_select_one(
+        row = await self.db_pool.simple_select_one(
             table="remote_media_cache_thumbnails",
             keyvalues={
                 "media_origin": origin,
@@ -615,11 +764,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 "thumbnail_method",
                 "thumbnail_type",
                 "thumbnail_length",
-                "filesystem_id",
             ),
             allow_none=True,
             desc="get_remote_media_thumbnail",
         )
+        if row is None:
+            return None
+        return ThumbnailInfo(
+            width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
+        )
 
     @trace
     async def store_remote_media_thumbnail(
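The media_repository.py changes replace Dict[str, Any] rows with frozen attrs classes (LocalMedia, RemoteMedia, UrlCache) built from positional row tuples, so the constructor's argument order has to track the SELECTed column order exactly. A minimal standalone sketch of that pattern, using a hypothetical UrlCacheRow class so as not to restate the definitions above:

    from typing import Optional, Tuple

    import attr

    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class UrlCacheRow:
        response_code: int
        expires_ts: int
        og: str

    def url_cache_from_row(
        row: Optional[Tuple[int, int, str]]
    ) -> Optional[UrlCacheRow]:
        # Positions must match the SELECT column order exactly:
        # (response_code, expires_ts, og).
        if row is None:
            return None
        return UrlCacheRow(response_code=row[0], expires_ts=row[1], og=row[2])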
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 4b1061e6d7..2911e53310 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -317,7 +317,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
             if user_id:
                 is_support = self.is_support_user_txn(txn, user_id)
                 if not is_support:
-                    # We do this manually here to avoid hitting #6791
+                    # We do this manually here to avoid hitting https://github.com/matrix-org/synapse/issues/6791
                     self.db_pool.simple_upsert_txn(
                         txn,
                         table="monthly_active_users",
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 3b444d2d07..0198bb09d2 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -363,10 +363,11 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
                 # for their user ID.
                 value_values=[(presence_stream_id,) for _ in user_ids],
             )
-            for user_id in user_ids:
-                self._invalidate_cache_and_stream(
-                    txn, self._get_full_presence_stream_token_for_user, (user_id,)
-                )
+            self._invalidate_cache_and_stream_bulk(
+                txn,
+                self._get_full_presence_stream_token_for_user,
+                [(user_id,) for user_id in user_ids],
+            )
 
         return await self.db_pool.runInteraction(
             "add_users_to_send_full_presence_to", _add_users_to_send_full_presence_to
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 3ba9cc8853..7ed111f632 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 from typing import TYPE_CHECKING, Optional
 
-from synapse.api.errors import StoreError
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
     DatabasePool,
@@ -138,23 +137,18 @@ class ProfileWorkerStore(SQLBaseStore):
         return 50
 
     async def get_profileinfo(self, user_id: UserID) -> ProfileInfo:
-        try:
-            profile = await self.db_pool.simple_select_one(
-                table="profiles",
-                keyvalues={"full_user_id": user_id.to_string()},
-                retcols=("displayname", "avatar_url"),
-                desc="get_profileinfo",
-            )
-        except StoreError as e:
-            if e.code == 404:
-                # no match
-                return ProfileInfo(None, None)
-            else:
-                raise
-
-        return ProfileInfo(
-            avatar_url=profile["avatar_url"], display_name=profile["displayname"]
+        profile = await self.db_pool.simple_select_one(
+            table="profiles",
+            keyvalues={"full_user_id": user_id.to_string()},
+            retcols=("displayname", "avatar_url"),
+            desc="get_profileinfo",
+            allow_none=True,
         )
+        if profile is None:
+            # no match
+            return ProfileInfo(None, None)
+
+        return ProfileInfo(avatar_url=profile[1], display_name=profile[0])
 
     async def get_profile_displayname(self, user_id: UserID) -> Optional[str]:
         return await self.db_pool.simple_select_one_onecol(
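The profile.py change drops the try/except around a 404-style StoreError in favour of allow_none=True plus a None check, reading the resulting row positionally. A minimal sketch of the same pattern, with a hypothetical select_one callable standing in for the store's helper:

    from typing import Callable, Optional, Tuple

    def load_profile(
        select_one: Callable[..., Optional[Tuple[str, str]]],
        full_user_id: str,
    ) -> Tuple[Optional[str], Optional[str]]:
        row = select_one(
            table="profiles",
            keyvalues={"full_user_id": full_user_id},
            retcols=("displayname", "avatar_url"),
            allow_none=True,
        )
        if row is None:
            # No profile row: treat as an empty profile rather than an error.
            return (None, None)
        displayname, avatar_url = row  # tuple order follows retcols
        return (displayname, avatar_url)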
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 1e11bf2706..1a5b5731bb 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -295,19 +295,28 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         # so make sure to keep this actually last.
         txn.execute("DROP TABLE events_to_purge")
 
-        for event_id, should_delete in event_rows:
-            self._invalidate_cache_and_stream(
-                txn, self._get_state_group_for_event, (event_id,)
-            )
+        self._invalidate_cache_and_stream_bulk(
+            txn,
+            self._get_state_group_for_event,
+            [(event_id,) for event_id, _ in event_rows],
+        )
 
-            # XXX: This is racy, since have_seen_events could be called between the
-            #    transaction completing and the invalidation running. On the other hand,
-            #    that's no different to calling `have_seen_events` just before the
-            #    event is deleted from the database.
+        # XXX: This is racy, since have_seen_events could be called between the
+        #    transaction completing and the invalidation running. On the other hand,
+        #    that's no different to calling `have_seen_events` just before the
+        #    event is deleted from the database.
+        self._invalidate_cache_and_stream_bulk(
+            txn,
+            self.have_seen_event,
+            [
+                (room_id, event_id)
+                for event_id, should_delete in event_rows
+                if should_delete
+            ],
+        )
+
+        for event_id, should_delete in event_rows:
             if should_delete:
-                self._invalidate_cache_and_stream(
-                    txn, self.have_seen_event, (room_id, event_id)
-                )
                 self.invalidate_get_event_cache_after_txn(txn, event_id)
 
         logger.info("[purge] done")
@@ -485,7 +494,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         #  - room_tags_revisions
         #       The problem with these is that they are largeish and there is no room_id
         #       index on them. In any case we should be clearing out 'stream' tables
-        #       periodically anyway (#5888)
+        #       periodically anyway (https://github.com/matrix-org/synapse/issues/5888)
 
         self._invalidate_caches_for_room_and_stream(txn, room_id)
 
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 22025eca56..cf622e195c 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -28,8 +28,11 @@ from typing import (
     cast,
 )
 
+from twisted.internet import defer
+
 from synapse.api.errors import StoreError
 from synapse.config.homeserver import ExperimentalConfig
+from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.replication.tcp.streams import PushRulesStream
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
@@ -51,7 +54,8 @@ from synapse.storage.util.id_generators import (
 )
 from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules
 from synapse.types import JsonDict
-from synapse.util import json_encoder
+from synapse.util import json_encoder, unwrapFirstError
+from synapse.util.async_helpers import gather_results
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
@@ -249,23 +253,33 @@ class PushRulesWorkerStore(
             user_id: [] for user_id in user_ids
         }
 
-        rows = cast(
-            List[Tuple[str, str, int, int, str, str]],
-            await self.db_pool.simple_select_many_batch(
-                table="push_rules",
-                column="user_name",
-                iterable=user_ids,
-                retcols=(
-                    "user_name",
-                    "rule_id",
-                    "priority_class",
-                    "priority",
-                    "conditions",
-                    "actions",
+        # gatherResults loses all type information.
+        rows, enabled_map_by_user = await make_deferred_yieldable(
+            gather_results(
+                (
+                    cast(
+                        "defer.Deferred[List[Tuple[str, str, int, int, str, str]]]",
+                        run_in_background(
+                            self.db_pool.simple_select_many_batch,
+                            table="push_rules",
+                            column="user_name",
+                            iterable=user_ids,
+                            retcols=(
+                                "user_name",
+                                "rule_id",
+                                "priority_class",
+                                "priority",
+                                "conditions",
+                                "actions",
+                            ),
+                            desc="bulk_get_push_rules",
+                            batch_size=1000,
+                        ),
+                    ),
+                    run_in_background(self.bulk_get_push_rules_enabled, user_ids),
                 ),
-                desc="bulk_get_push_rules",
-                batch_size=1000,
-            ),
+                consumeErrors=True,
+            ).addErrback(unwrapFirstError)
         )
 
         # Sort by highest priority_class, then highest priority.
@@ -276,8 +290,6 @@ class PushRulesWorkerStore(
                 (rule_id, priority_class, conditions, actions)
             )
 
-        enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
-
         results: Dict[str, FilteredPushRules] = {}
 
         for user_id, rules in raw_rules.items():
@@ -437,27 +449,28 @@ class PushRuleStore(PushRulesWorkerStore):
         before: str,
         after: str,
     ) -> None:
-        # Lock the table since otherwise we'll have annoying races between the
-        # SELECT here and the UPSERT below.
-        self.database_engine.lock_table(txn, "push_rules")
-
         relative_to_rule = before or after
 
-        res = self.db_pool.simple_select_one_txn(
-            txn,
-            table="push_rules",
-            keyvalues={"user_name": user_id, "rule_id": relative_to_rule},
-            retcols=["priority_class", "priority"],
-            allow_none=True,
-        )
+        sql = """
+            SELECT priority, priority_class FROM push_rules
+            WHERE user_name = ? AND rule_id = ?
+        """
 
-        if not res:
+        if isinstance(self.database_engine, PostgresEngine):
+            sql += " FOR UPDATE"
+        else:
+            # Annoyingly SQLite doesn't support row level locking, so lock the whole table
+            self.database_engine.lock_table(txn, "push_rules")
+
+        txn.execute(sql, (user_id, relative_to_rule))
+        row = txn.fetchone()
+
+        if row is None:
             raise RuleNotFoundException(
                 "before/after rule not found: %s" % (relative_to_rule,)
             )
 
-        base_priority_class = res["priority_class"]
-        base_rule_priority = res["priority"]
+        base_rule_priority, base_priority_class = row
 
         if base_priority_class != priority_class:
             raise InconsistentRuleException(
@@ -505,9 +518,18 @@ class PushRuleStore(PushRulesWorkerStore):
         conditions_json: str,
         actions_json: str,
     ) -> None:
-        # Lock the table since otherwise we'll have annoying races between the
-        # SELECT here and the UPSERT below.
-        self.database_engine.lock_table(txn, "push_rules")
+        if isinstance(self.database_engine, PostgresEngine):
+            # Postgres doesn't do FOR UPDATE on aggregate functions, so select the rows first
+            # then re-select the count/max below.
+            sql = """
+                SELECT * FROM push_rules
+                WHERE user_name = ? and priority_class = ?
+                FOR UPDATE
+            """
+            txn.execute(sql, (user_id, priority_class))
+        else:
+            # Annoyingly SQLite doesn't support row level locking, so lock the whole table
+            self.database_engine.lock_table(txn, "push_rules")
 
         # find the highest priority rule in that class
         sql = (
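The push_rule.py hunks replace the coarse lock_table calls with SELECT ... FOR UPDATE on Postgres (keeping the table lock only for SQLite, which has no row-level locks) and run the two bulk reads concurrently. A minimal sketch of the locking decision, assuming a DB-API style cursor, the ?-placeholder style used above, and a hypothetical lock_table callable:

    from typing import Callable, Optional, Tuple

    def fetch_relative_rule_locked(
        cur,
        lock_table: Callable[[str], None],
        is_postgres: bool,
        user_id: str,
        rule_id: str,
    ) -> Optional[Tuple[int, int]]:
        sql = """
            SELECT priority, priority_class FROM push_rules
            WHERE user_name = ? AND rule_id = ?
        """
        if is_postgres:
            # Row-level lock, held until the surrounding transaction commits.
            sql += " FOR UPDATE"
        else:
            # SQLite fallback: lock the whole table to avoid the SELECT/UPSERT race.
            lock_table("push_rules")
        cur.execute(sql, (user_id, rule_id))
        return cur.fetchone()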
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 56e8eb16a8..3484ce9ef9 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -701,8 +701,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
             allow_none=True,
         )
 
-        stream_ordering = int(res["stream_ordering"]) if res else None
-        rx_ts = res["received_ts"] if res else 0
+        stream_ordering = int(res[0]) if res else None
+        rx_ts = res[1] if res else 0
 
         # We don't want to clobber receipts for more recent events, so we
         # have to compare orderings of existing receipts
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index e09ab21593..2c3f30e2eb 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -425,17 +425,14 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                     account timestamp as milliseconds since the epoch. None if the account
                     has not been renewed using the current token yet.
         """
-        ret_dict = await self.db_pool.simple_select_one(
-            table="account_validity",
-            keyvalues={"renewal_token": renewal_token},
-            retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"],
-            desc="get_user_from_renewal_token",
-        )
-
-        return (
-            ret_dict["user_id"],
-            ret_dict["expiration_ts_ms"],
-            ret_dict["token_used_ts_ms"],
+        return cast(
+            Tuple[str, int, Optional[int]],
+            await self.db_pool.simple_select_one(
+                table="account_validity",
+                keyvalues={"renewal_token": renewal_token},
+                retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"],
+                desc="get_user_from_renewal_token",
+            ),
         )
 
     async def get_renewal_token_for_user(self, user_id: str) -> str:
@@ -564,16 +561,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 updatevalues={"shadow_banned": shadow_banned},
             )
             # In order for this to apply immediately, clear the cache for this user.
-            tokens = self.db_pool.simple_select_onecol_txn(
+            tokens = self.db_pool.simple_select_list_txn(
                 txn,
                 table="access_tokens",
                 keyvalues={"user_id": user_id},
-                retcol="token",
+                retcols=("token",),
+            )
+            self._invalidate_cache_and_stream_bulk(
+                txn, self.get_user_by_access_token, tokens
             )
-            for token in tokens:
-                self._invalidate_cache_and_stream(
-                    txn, self.get_user_by_access_token, (token,)
-                )
             self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
 
         await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn)
@@ -989,16 +985,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         Returns:
             user id, or None if no user id/threepid mapping exists
         """
-        ret = self.db_pool.simple_select_one_txn(
+        return self.db_pool.simple_select_one_onecol_txn(
             txn,
             "user_threepids",
             {"medium": medium, "address": address},
-            ["user_id"],
+            "user_id",
             True,
         )
-        if ret:
-            return ret["user_id"]
-        return None
 
     async def user_add_threepid(
         self,
@@ -1435,16 +1428,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         if res is None:
             return False
 
+        uses_allowed, pending, completed, expiry_time = res
+
         # Check if the token has expired
         now = self._clock.time_msec()
-        if res["expiry_time"] and res["expiry_time"] < now:
+        if expiry_time and expiry_time < now:
             return False
 
         # Check if the token has been used up
-        if (
-            res["uses_allowed"]
-            and res["pending"] + res["completed"] >= res["uses_allowed"]
-        ):
+        if uses_allowed and pending + completed >= uses_allowed:
             return False
 
         # Otherwise, the token is valid
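The unpacked tuple makes the validity rules easy to read: a token is invalid once its expiry time has passed, or once pending plus completed uses reach the allowed total. The same predicate as a standalone sketch (hypothetical helper, mirroring the truthiness checks above):

    from typing import Optional

    def registration_token_is_valid(
        uses_allowed: Optional[int],
        pending: int,
        completed: int,
        expiry_time: Optional[int],
        now_ms: int,
    ) -> bool:
        # A token past its expiry time is no longer valid.
        if expiry_time and expiry_time < now_ms:
            return False
        # A limited token is used up once pending + completed uses reach the
        # allowed total; a None uses_allowed means there is no limit.
        if uses_allowed and pending + completed >= uses_allowed:
            return False
        return True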
@@ -1490,8 +1482,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             # Override type because the return type is only optional if
             # allow_none is True, and we don't want mypy throwing errors
             # about None not being indexable.
-            res = cast(
-                Dict[str, Any],
+            pending, completed = cast(
+                Tuple[int, int],
                 self.db_pool.simple_select_one_txn(
                     txn,
                     "registration_tokens",
@@ -1506,8 +1498,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 "registration_tokens",
                 keyvalues={"token": token},
                 updatevalues={
-                    "completed": res["completed"] + 1,
-                    "pending": res["pending"] - 1,
+                    "completed": completed + 1,
+                    "pending": pending - 1,
                 },
             )
 
@@ -1517,7 +1509,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
     async def get_registration_tokens(
         self, valid: Optional[bool] = None
-    ) -> List[Dict[str, Any]]:
+    ) -> List[Tuple[str, Optional[int], int, int, Optional[int]]]:
         """List all registration tokens. Used by the admin API.
 
         Args:
@@ -1526,34 +1518,48 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
               Default is None: return all tokens regardless of validity.
 
         Returns:
-            A list of dicts, each containing details of a token.
+            A list of tuples, each containing:
+                * The token
+                * The number of uses allowed (or None if unlimited)
+                * The number of pending uses
+                * The number of completed uses
+                * An expiry time (or None if no expiry)
         """
 
         def select_registration_tokens_txn(
             txn: LoggingTransaction, now: int, valid: Optional[bool]
-        ) -> List[Dict[str, Any]]:
+        ) -> List[Tuple[str, Optional[int], int, int, Optional[int]]]:
             if valid is None:
                 # Return all tokens regardless of validity
-                txn.execute("SELECT * FROM registration_tokens")
+                txn.execute(
+                    """
+                    SELECT token, uses_allowed, pending, completed, expiry_time
+                    FROM registration_tokens
+                    """
+                )
 
             elif valid:
                 # Select valid tokens only
-                sql = (
-                    "SELECT * FROM registration_tokens WHERE "
-                    "(uses_allowed > pending + completed OR uses_allowed IS NULL) "
-                    "AND (expiry_time > ? OR expiry_time IS NULL)"
-                )
+                sql = """
+                SELECT token, uses_allowed, pending, completed, expiry_time
+                FROM registration_tokens
+                WHERE (uses_allowed > pending + completed OR uses_allowed IS NULL)
+                    AND (expiry_time > ? OR expiry_time IS NULL)
+                """
                 txn.execute(sql, [now])
 
             else:
                 # Select invalid tokens only
-                sql = (
-                    "SELECT * FROM registration_tokens WHERE "
-                    "uses_allowed <= pending + completed OR expiry_time <= ?"
-                )
+                sql = """
+                SELECT token, uses_allowed, pending, completed, expiry_time
+                FROM registration_tokens
+                WHERE uses_allowed <= pending + completed OR expiry_time <= ?
+                """
                 txn.execute(sql, [now])
 
-            return self.db_pool.cursor_to_dict(txn)
+            return cast(
+                List[Tuple[str, Optional[int], int, int, Optional[int]]], txn.fetchall()
+            )
 
         return await self.db_pool.runInteraction(
             "select_registration_tokens",
@@ -1571,13 +1577,22 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         Returns:
             A dict, or None if token doesn't exist.
         """
-        return await self.db_pool.simple_select_one(
+        row = await self.db_pool.simple_select_one(
             "registration_tokens",
             keyvalues={"token": token},
             retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"],
             allow_none=True,
             desc="get_one_registration_token",
         )
+        if row is None:
+            return None
+        return {
+            "token": row[0],
+            "uses_allowed": row[1],
+            "pending": row[2],
+            "completed": row[3],
+            "expiry_time": row[4],
+        }
 
     async def generate_registration_token(
         self, length: int, chars: str
@@ -1700,7 +1715,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 return None
 
             # Get all info about the token so it can be sent in the response
-            return self.db_pool.simple_select_one_txn(
+            result = self.db_pool.simple_select_one_txn(
                 txn,
                 "registration_tokens",
                 keyvalues={"token": token},
@@ -1714,6 +1729,17 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 allow_none=True,
             )
 
+            if result is None:
+                return result
+
+            return {
+                "token": result[0],
+                "uses_allowed": result[1],
+                "pending": result[2],
+                "completed": result[3],
+                "expiry_time": result[4],
+            }
+
         return await self.db_pool.runInteraction(
             "update_registration_token", _update_registration_token_txn
         )
@@ -1925,11 +1951,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             keyvalues={"token": token},
             updatevalues={"used_ts": ts},
         )
-        user_id = values["user_id"]
-        expiry_ts = values["expiry_ts"]
-        used_ts = values["used_ts"]
-        auth_provider_id = values["auth_provider_id"]
-        auth_provider_session_id = values["auth_provider_session_id"]
+        (
+            user_id,
+            expiry_ts,
+            used_ts,
+            auth_provider_id,
+            auth_provider_session_id,
+        ) = values
 
         # Token was already used
         if used_ts is not None:
@@ -2654,10 +2682,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
             )
             tokens_and_devices = [(r[0], r[1], r[2]) for r in txn]
 
-            for token, _, _ in tokens_and_devices:
-                self._invalidate_cache_and_stream(
-                    txn, self.get_user_by_access_token, (token,)
-                )
+            self._invalidate_cache_and_stream_bulk(
+                txn,
+                self.get_user_by_access_token,
+                [(token,) for token, _, _ in tokens_and_devices],
+            )
 
             txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values)
 
@@ -2742,12 +2771,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
                     # reason, the next check is on the client secret, which is NOT NULL,
                     # so we don't have to worry about the client secret matching by
                     # accident.
-                    row = {"client_secret": None, "validated_at": None}
+                    row = None, None
                 else:
                     raise ThreepidValidationError("Unknown session_id")
 
-            retrieved_client_secret = row["client_secret"]
-            validated_at = row["validated_at"]
+            retrieved_client_secret, validated_at = row
 
             row = self.db_pool.simple_select_one_txn(
                 txn,
@@ -2761,8 +2789,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
                 raise ThreepidValidationError(
                     "Validation token not found or has expired"
                 )
-            expires = row["expires"]
-            next_link = row["next_link"]
+            expires, next_link = row
 
             if retrieved_client_secret != client_secret:
                 raise ThreepidValidationError(
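Both access-token hunks in this file collapse a per-token loop into a single _invalidate_cache_and_stream_bulk call, which takes the cached function and a list of key tuples. A small sketch of the calling pattern (the wrapper function is hypothetical):

    def invalidate_access_token_caches(store, txn, tokens) -> None:
        # One bulk call replaces invalidating get_user_by_access_token once per
        # token; each cache key is a 1-tuple containing the token.
        store._invalidate_cache_and_stream_bulk(
            txn,
            store.get_user_by_access_token,
            [(token,) for token in tokens],
        )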
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 3e8fcf1975..ef26d5d9d3 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -78,6 +78,31 @@ class RatelimitOverride:
     burst_count: int
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class LargestRoomStats:
+    room_id: str
+    name: Optional[str]
+    canonical_alias: Optional[str]
+    joined_members: int
+    join_rules: Optional[str]
+    guest_access: Optional[str]
+    history_visibility: Optional[str]
+    state_events: int
+    avatar: Optional[str]
+    topic: Optional[str]
+    room_type: Optional[str]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RoomStats(LargestRoomStats):
+    joined_local_members: int
+    version: Optional[str]
+    creator: Optional[str]
+    encryption: Optional[str]
+    federatable: bool
+    public: bool
+
+
 class RoomSortOrder(Enum):
     """
     Enum to define the sorting method used when returning rooms with get_rooms_paginate
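With get_room_with_stats now returning a frozen attrs class instead of a dict, callers read typed attributes rather than string keys. A brief, hypothetical caller showing the access pattern:

    async def room_is_large(store, room_id: str, threshold: int = 1000) -> bool:
        # RoomStats fields are plain attributes (stats.joined_members) rather
        # than dict lookups (stats["joined_members"]).
        stats = await store.get_room_with_stats(room_id)
        if stats is None:
            return False
        return stats.joined_members >= threshold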
@@ -188,23 +213,33 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             logger.error("store_room with room_id=%s failed: %s", room_id, e)
             raise StoreError(500, "Problem creating room.")
 
-    async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]:
+    async def get_room(self, room_id: str) -> Optional[Tuple[bool, bool]]:
         """Retrieve a room.
 
         Args:
             room_id: The ID of the room to retrieve.
         Returns:
-            A dict containing the room information, or None if the room is unknown.
+            A tuple containing the room information:
+                * True if the room is public
+                * True if the room has an auth chain index
+
+            or None if the room is unknown.
         """
-        return await self.db_pool.simple_select_one(
-            table="rooms",
-            keyvalues={"room_id": room_id},
-            retcols=("room_id", "is_public", "creator", "has_auth_chain_index"),
-            desc="get_room",
-            allow_none=True,
+        row = cast(
+            Optional[Tuple[Optional[Union[int, bool]], Optional[Union[int, bool]]]],
+            await self.db_pool.simple_select_one(
+                table="rooms",
+                keyvalues={"room_id": room_id},
+                retcols=("is_public", "has_auth_chain_index"),
+                desc="get_room",
+                allow_none=True,
+            ),
         )
+        if row is None:
+            return row
+        return bool(row[0]), bool(row[1])
 
-    async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]:
+    async def get_room_with_stats(self, room_id: str) -> Optional[RoomStats]:
         """Retrieve room with statistics.
 
         Args:
@@ -215,7 +250,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
 
         def get_room_with_stats_txn(
             txn: LoggingTransaction, room_id: str
-        ) -> Optional[Dict[str, Any]]:
+        ) -> Optional[RoomStats]:
             sql = """
                 SELECT room_id, state.name, state.canonical_alias, curr.joined_members,
                   curr.local_users_in_room AS joined_local_members, rooms.room_version AS version,
@@ -229,15 +264,28 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
                 WHERE room_id = ?
                 """
             txn.execute(sql, [room_id])
-            # Catch error if sql returns empty result to return "None" instead of an error
-            try:
-                res = self.db_pool.cursor_to_dict(txn)[0]
-            except IndexError:
+            row = txn.fetchone()
+            if not row:
                 return None
-
-            res["federatable"] = bool(res["federatable"])
-            res["public"] = bool(res["public"])
-            return res
+            return RoomStats(
+                room_id=row[0],
+                name=row[1],
+                canonical_alias=row[2],
+                joined_members=row[3],
+                joined_local_members=row[4],
+                version=row[5],
+                creator=row[6],
+                encryption=row[7],
+                federatable=bool(row[8]),
+                public=bool(row[9]),
+                join_rules=row[10],
+                guest_access=row[11],
+                history_visibility=row[12],
+                state_events=row[13],
+                avatar=row[14],
+                topic=row[15],
+                room_type=row[16],
+            )
 
         return await self.db_pool.runInteraction(
             "get_room_with_stats", get_room_with_stats_txn, room_id
@@ -368,7 +416,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         bounds: Optional[Tuple[int, str]],
         forwards: bool,
         ignore_non_federatable: bool = False,
-    ) -> List[Dict[str, Any]]:
+    ) -> List[LargestRoomStats]:
         """Gets the largest public rooms (where largest is in terms of joined
         members, as tracked in the statistics table).
 
@@ -505,20 +553,34 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
 
         def _get_largest_public_rooms_txn(
             txn: LoggingTransaction,
-        ) -> List[Dict[str, Any]]:
+        ) -> List[LargestRoomStats]:
             txn.execute(sql, query_args)
 
-            results = self.db_pool.cursor_to_dict(txn)
+            results = [
+                LargestRoomStats(
+                    room_id=r[0],
+                    name=r[1],
+                    canonical_alias=r[3],
+                    joined_members=r[4],
+                    join_rules=r[8],
+                    guest_access=r[7],
+                    history_visibility=r[6],
+                    state_events=0,
+                    avatar=r[5],
+                    topic=r[2],
+                    room_type=r[9],
+                )
+                for r in txn
+            ]
 
             if not forwards:
                 results.reverse()
 
             return results
 
-        ret_val = await self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_largest_public_rooms", _get_largest_public_rooms_txn
         )
-        return ret_val
 
     @cached(max_entries=10000)
     async def is_room_blocked(self, room_id: str) -> Optional[bool]:
@@ -742,10 +804,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         )
 
         if row:
-            return RatelimitOverride(
-                messages_per_second=row["messages_per_second"],
-                burst_count=row["burst_count"],
-            )
+            return RatelimitOverride(messages_per_second=row[0], burst_count=row[1])
         else:
             return None
 
@@ -1319,13 +1378,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         join.
         """
 
-        result = await self.db_pool.simple_select_one(
-            table="partial_state_rooms",
-            keyvalues={"room_id": room_id},
-            retcols=("join_event_id", "device_lists_stream_id"),
-            desc="get_join_event_id_for_partial_state",
+        return cast(
+            Tuple[str, int],
+            await self.db_pool.simple_select_one(
+                table="partial_state_rooms",
+                keyvalues={"room_id": room_id},
+                retcols=("join_event_id", "device_lists_stream_id"),
+                desc="get_join_event_id_for_partial_state",
+            ),
         )
-        return result["join_event_id"], result["device_lists_stream_id"]
 
     def get_un_partial_stated_rooms_token(self, instance_name: str) -> int:
         return self._un_partial_stated_rooms_stream_id_gen.get_current_token_for_writer(
@@ -2216,7 +2277,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
             txn,
             table="partial_state_rooms_servers",
             keys=("room_id", "server_name"),
-            values=((room_id, s) for s in servers),
+            values=[(room_id, s) for s in servers],
         )
         self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
         self._invalidate_cache_and_stream(
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 1ed7f2d0ef..60d4a9ef30 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -559,17 +559,20 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
                 "non-local user %s" % (user_id,),
             )
 
-        results_dict = await self.db_pool.simple_select_one(
-            "local_current_membership",
-            {"room_id": room_id, "user_id": user_id},
-            ("membership", "event_id"),
-            allow_none=True,
-            desc="get_local_current_membership_for_user_in_room",
+        results = cast(
+            Optional[Tuple[str, str]],
+            await self.db_pool.simple_select_one(
+                "local_current_membership",
+                {"room_id": room_id, "user_id": user_id},
+                ("membership", "event_id"),
+                allow_none=True,
+                desc="get_local_current_membership_for_user_in_room",
+            ),
         )
-        if not results_dict:
+        if not results:
             return None, None
 
-        return results_dict.get("membership"), results_dict.get("event_id")
+        return results
 
     @cached(max_entries=500000, iterable=True)
     async def get_rooms_for_user_with_stream_ordering(
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index dbde9130c6..e25d86818b 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -106,7 +106,7 @@ class SearchWorkerStore(SQLBaseStore):
                 txn,
                 table="event_search",
                 keys=("event_id", "room_id", "key", "value"),
-                values=(
+                values=[
                     (
                         entry.event_id,
                         entry.room_id,
@@ -114,7 +114,7 @@ class SearchWorkerStore(SQLBaseStore):
                         _clean_value_for_search(entry.value),
                     )
                     for entry in entries
-                ),
+                ],
             )
 
         else:
@@ -275,7 +275,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
             # we have to set autocommit, because postgres refuses to
             # CREATE INDEX CONCURRENTLY without it.
-            conn.set_session(autocommit=True)
+            conn.engine.attempt_to_set_autocommit(conn.conn, True)
 
             try:
                 c = conn.cursor()
@@ -301,7 +301,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
                 # we should now be able to delete the GIST index.
                 c.execute("DROP INDEX IF EXISTS event_search_fts_idx_gist")
             finally:
-                conn.set_session(autocommit=False)
+                conn.engine.attempt_to_set_autocommit(conn.conn, False)
 
         if isinstance(self.database_engine, PostgresEngine):
             await self.db_pool.runWithConnection(create_index)
@@ -323,7 +323,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
             def create_index(conn: LoggingDatabaseConnection) -> None:
                 conn.rollback()
-                conn.set_session(autocommit=True)
+                conn.engine.attempt_to_set_autocommit(conn.conn, True)
                 c = conn.cursor()
 
                 # We create with NULLS FIRST so that when we search *backwards*
@@ -340,7 +340,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
                     ON event_search(origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)
                     """
                 )
-                conn.set_session(autocommit=False)
+                conn.engine.attempt_to_set_autocommit(conn.conn, False)
 
             await self.db_pool.runWithConnection(create_index)
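These hunks swap the psycopg2-specific conn.set_session(autocommit=True) for the engine-level attempt_to_set_autocommit helper; Postgres needs autocommit here because CREATE INDEX CONCURRENTLY refuses to run inside a transaction block. A minimal sketch of the pattern (the index and table names are illustrative):

    def create_index_concurrently(conn) -> None:
        # CREATE INDEX CONCURRENTLY cannot run inside a transaction on Postgres,
        # so roll back any open transaction and switch to autocommit first.
        conn.rollback()
        conn.engine.attempt_to_set_autocommit(conn.conn, True)
        try:
            c = conn.cursor()
            c.execute(
                "CREATE INDEX CONCURRENTLY IF NOT EXISTS example_idx"
                " ON example_table(example_col)"
            )
        finally:
            # Restore normal transactional behaviour afterwards.
            conn.engine.attempt_to_set_autocommit(conn.conn, False)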
 
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 2225f8272d..563c275a2c 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -1014,9 +1014,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             desc="get_position_for_event",
         )
 
-        return PersistedEventPosition(
-            row["instance_name"] or "master", row["stream_ordering"]
-        )
+        return PersistedEventPosition(row[1] or "master", row[0])
 
     async def get_topological_token_for_event(self, event_id: str) -> RoomStreamToken:
         """The stream token for an event
@@ -1033,9 +1031,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             retcols=("stream_ordering", "topological_ordering"),
             desc="get_topological_token_for_event",
         )
-        return RoomStreamToken(
-            topological=row["topological_ordering"], stream=row["stream_ordering"]
-        )
+        return RoomStreamToken(topological=row[1], stream=row[0])
 
     async def get_current_topological_token(self, room_id: str, stream_key: int) -> int:
         """Gets the topological token in a room after or at the given stream
@@ -1180,26 +1176,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             dict
         """
 
-        results = self.db_pool.simple_select_one_txn(
-            txn,
-            "events",
-            keyvalues={"event_id": event_id, "room_id": room_id},
-            retcols=["stream_ordering", "topological_ordering"],
+        stream_ordering, topological_ordering = cast(
+            Tuple[int, int],
+            self.db_pool.simple_select_one_txn(
+                txn,
+                "events",
+                keyvalues={"event_id": event_id, "room_id": room_id},
+                retcols=["stream_ordering", "topological_ordering"],
+            ),
         )
 
-        # This cannot happen as `allow_none=False`.
-        assert results is not None
-
         # Paginating backwards includes the event at the token, but paginating
         # forward doesn't.
         before_token = RoomStreamToken(
-            topological=results["topological_ordering"] - 1,
-            stream=results["stream_ordering"],
+            topological=topological_ordering - 1, stream=stream_ordering
         )
 
         after_token = RoomStreamToken(
-            topological=results["topological_ordering"],
-            stream=results["stream_ordering"],
+            topological=topological_ordering, stream=stream_ordering
         )
 
         rows, start_token = self._paginate_room_events_txn(
diff --git a/synapse/storage/databases/main/task_scheduler.py b/synapse/storage/databases/main/task_scheduler.py
index 5555b53575..64543b4d61 100644
--- a/synapse/storage/databases/main/task_scheduler.py
+++ b/synapse/storage/databases/main/task_scheduler.py
@@ -183,39 +183,27 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
 
         Returns: the task if available, `None` otherwise
         """
-        row = await self.db_pool.simple_select_one(
-            table="scheduled_tasks",
-            keyvalues={"id": id},
-            retcols=(
-                "id",
-                "action",
-                "status",
-                "timestamp",
-                "resource_id",
-                "params",
-                "result",
-                "error",
+        row = cast(
+            Optional[ScheduledTaskRow],
+            await self.db_pool.simple_select_one(
+                table="scheduled_tasks",
+                keyvalues={"id": id},
+                retcols=(
+                    "id",
+                    "action",
+                    "status",
+                    "timestamp",
+                    "resource_id",
+                    "params",
+                    "result",
+                    "error",
+                ),
+                allow_none=True,
+                desc="get_scheduled_task",
             ),
-            allow_none=True,
-            desc="get_scheduled_task",
         )
 
-        return (
-            TaskSchedulerWorkerStore._convert_row_to_task(
-                (
-                    row["id"],
-                    row["action"],
-                    row["status"],
-                    row["timestamp"],
-                    row["resource_id"],
-                    row["params"],
-                    row["result"],
-                    row["error"],
-                )
-            )
-            if row
-            else None
-        )
+        return TaskSchedulerWorkerStore._convert_row_to_task(row) if row else None
 
     async def delete_scheduled_task(self, id: str) -> None:
         """Delete a specific task from its id.
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index fecddb4144..2d341affaa 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -118,19 +118,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
             txn,
             table="received_transactions",
             keyvalues={"transaction_id": transaction_id, "origin": origin},
-            retcols=(
-                "transaction_id",
-                "origin",
-                "ts",
-                "response_code",
-                "response_json",
-                "has_been_referenced",
-            ),
+            retcols=("response_code", "response_json"),
             allow_none=True,
         )
 
-        if result and result["response_code"]:
-            return result["response_code"], db_to_json(result["response_json"])
+        # If the result exists and the response code is non-zero.
+        if result and result[0]:
+            return result[0], db_to_json(result[1])
 
         else:
             return None
@@ -200,8 +194,10 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
 
         # check we have a row and retry_last_ts is not null or zero
         # (retry_last_ts can't be negative)
-        if result and result["retry_last_ts"]:
-            return DestinationRetryTimings(**result)
+        if result and result[1]:
+            return DestinationRetryTimings(
+                failure_ts=result[0], retry_last_ts=result[1], retry_interval=result[2]
+            )
         else:
             return None
 
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 8ab7c42c4a..5b164fed8e 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -122,9 +122,13 @@ class UIAuthWorkerStore(SQLBaseStore):
             desc="get_ui_auth_session",
         )
 
-        result["clientdict"] = db_to_json(result["clientdict"])
-
-        return UIAuthSessionData(session_id, **result)
+        return UIAuthSessionData(
+            session_id,
+            clientdict=db_to_json(result[0]),
+            uri=result[1],
+            method=result[2],
+            description=result[3],
+        )
 
     async def mark_ui_auth_stage_complete(
         self,
@@ -231,18 +235,15 @@ class UIAuthWorkerStore(SQLBaseStore):
         self, txn: LoggingTransaction, session_id: str, key: str, value: Any
     ) -> None:
         # Get the current value.
-        result = cast(
-            Dict[str, Any],
-            self.db_pool.simple_select_one_txn(
-                txn,
-                table="ui_auth_sessions",
-                keyvalues={"session_id": session_id},
-                retcols=("serverdict",),
-            ),
+        result = self.db_pool.simple_select_one_onecol_txn(
+            txn,
+            table="ui_auth_sessions",
+            keyvalues={"session_id": session_id},
+            retcol="serverdict",
         )
 
         # Update it and add it back to the database.
-        serverdict = db_to_json(result["serverdict"])
+        serverdict = db_to_json(result)
         serverdict[key] = value
 
         self.db_pool.simple_update_one_txn(
@@ -265,14 +266,14 @@ class UIAuthWorkerStore(SQLBaseStore):
         Raises:
             StoreError if the session cannot be found.
         """
-        result = await self.db_pool.simple_select_one(
+        result = await self.db_pool.simple_select_one_onecol(
             table="ui_auth_sessions",
             keyvalues={"session_id": session_id},
-            retcols=("serverdict",),
+            retcol="serverdict",
             desc="get_ui_auth_session_data",
         )
 
-        serverdict = db_to_json(result["serverdict"])
+        serverdict = db_to_json(result)
 
         return serverdict.get(key, default)
 
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index a9f5d68b63..1a38f3d785 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -20,7 +20,6 @@ from typing import (
     Collection,
     Iterable,
     List,
-    Mapping,
     Optional,
     Sequence,
     Set,
@@ -833,13 +832,25 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             "delete_all_from_user_dir", _delete_all_from_user_dir_txn
         )
 
-    async def _get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]:
-        return await self.db_pool.simple_select_one(
-            table="user_directory",
-            keyvalues={"user_id": user_id},
-            retcols=("display_name", "avatar_url"),
-            allow_none=True,
-            desc="get_user_in_directory",
+    async def _get_user_in_directory(
+        self, user_id: str
+    ) -> Optional[Tuple[Optional[str], Optional[str]]]:
+        """
+        Fetch the user information in the user directory.
+
+        Returns:
+            None if the user is unknown, otherwise a tuple of display name and
+            avatar URL (both of which may be None).
+        """
+        return cast(
+            Optional[Tuple[Optional[str], Optional[str]]],
+            await self.db_pool.simple_select_one(
+                table="user_directory",
+                keyvalues={"user_id": user_id},
+                retcols=("display_name", "avatar_url"),
+                allow_none=True,
+                desc="get_user_in_directory",
+            ),
         )
 
     async def update_user_directory_stream_pos(self, stream_id: Optional[int]) -> None:
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index 0f9c550b27..2c3151526d 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -492,7 +492,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
             conn.rollback()
             if isinstance(self.database_engine, PostgresEngine):
                 # postgres insists on autocommit for the index
-                conn.set_session(autocommit=True)
+                conn.engine.attempt_to_set_autocommit(conn.conn, True)
                 try:
                     txn = conn.cursor()
                     txn.execute(
@@ -501,7 +501,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
                     )
                     txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
                 finally:
-                    conn.set_session(autocommit=False)
+                    conn.engine.attempt_to_set_autocommit(conn.conn, False)
             else:
                 txn = conn.cursor()
                 txn.execute(
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 6309363217..ec4c4041b7 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -38,7 +38,8 @@ class PostgresEngine(
         super().__init__(psycopg2, database_config)
         psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
 
-        # Disables passing `bytes` to txn.execute, c.f. #6186. If you do
+        # Disables passing `bytes` to txn.execute, c.f.
+        # https://github.com/matrix-org/synapse/issues/6186. If you do
         # actually want to use bytes than wrap it in `bytearray`.
         def _disable_bytes_adapter(_: bytes) -> NoReturn:
             raise Exception("Passing bytes to DB is disabled.")
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 158b528dce..03e5a0f55d 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -109,7 +109,8 @@ Changes in SCHEMA_VERSION = 78
 
 Changes in SCHEMA_VERSION = 79
     - Add tables to handle in DB read-write locks.
-    - Add some mitigations for a painful race between foreground and background updates, cf #15677.
+    - Add some mitigations for a painful race between foreground and background updates, cf
+      https://github.com/matrix-org/synapse/issues/15677.
 
 Changes in SCHEMA_VERSION = 80
     - The event_txn_id_device_id is always written to for new events.
diff --git a/synapse/storage/schema/main/delta/54/delete_forward_extremities.sql b/synapse/storage/schema/main/delta/54/delete_forward_extremities.sql
index b062ec840c..f713e42aa0 100644
--- a/synapse/storage/schema/main/delta/54/delete_forward_extremities.sql
+++ b/synapse/storage/schema/main/delta/54/delete_forward_extremities.sql
@@ -14,7 +14,7 @@
  */
 
 -- Start a background job to cleanup extremities that were incorrectly added
--- by bug #5269.
+-- by bug https://github.com/matrix-org/synapse/issues/5269.
 INSERT INTO background_updates (update_name, progress_json) VALUES
   ('delete_soft_failed_extremities', '{}');
 
diff --git a/synapse/storage/schema/main/delta/56/remove_tombstoned_rooms_from_directory.sql b/synapse/storage/schema/main/delta/56/remove_tombstoned_rooms_from_directory.sql
index aeb17813d3..246c3359f7 100644
--- a/synapse/storage/schema/main/delta/56/remove_tombstoned_rooms_from_directory.sql
+++ b/synapse/storage/schema/main/delta/56/remove_tombstoned_rooms_from_directory.sql
@@ -13,6 +13,7 @@
  * limitations under the License.
  */
 
--- Now that #6232 is a thing, we can remove old rooms from the directory.
+-- Now that https://github.com/matrix-org/synapse/pull/6232 is a thing, we can
+-- remove old rooms from the directory.
 INSERT INTO background_updates (update_name, progress_json) VALUES
   ('remove_tombstoned_rooms_from_directory', '{}');
diff --git a/synapse/storage/schema/main/delta/70/01clean_table_purged_rooms.sql b/synapse/storage/schema/main/delta/70/01clean_table_purged_rooms.sql
index aed79635b2..31a61defa7 100644
--- a/synapse/storage/schema/main/delta/70/01clean_table_purged_rooms.sql
+++ b/synapse/storage/schema/main/delta/70/01clean_table_purged_rooms.sql
@@ -13,7 +13,8 @@
  * limitations under the License.
  */
 
--- Clean up left over rows from bug #11833, which was fixed in #12770.
+-- Clean up left over rows from bug https://github.com/matrix-org/synapse/issues/11833,
+-- which was fixed in https://github.com/matrix-org/synapse/pull/12770.
 DELETE FROM federation_inbound_events_staging WHERE room_id not in (
     SELECT room_id FROM rooms
 );
diff --git a/synapse/storage/schema/main/delta/83/05_cross_signing_key_update_grant.sql b/synapse/storage/schema/main/delta/83/05_cross_signing_key_update_grant.sql
new file mode 100644
index 0000000000..b74bdd71fa
--- /dev/null
+++ b/synapse/storage/schema/main/delta/83/05_cross_signing_key_update_grant.sql
@@ -0,0 +1,15 @@
+/* Copyright 2023 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+ALTER TABLE e2e_cross_signing_keys ADD COLUMN updatable_without_uia_before_ms bigint DEFAULT NULL;
\ No newline at end of file
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 9c3eafb562..bd3c81827f 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -650,8 +650,8 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
 
         next_id = self._load_next_id_txn(txn)
 
-        txn.call_after(self._mark_id_as_finished, next_id)
-        txn.call_on_exception(self._mark_id_as_finished, next_id)
+        txn.call_after(self._mark_ids_as_finished, [next_id])
+        txn.call_on_exception(self._mark_ids_as_finished, [next_id])
         txn.call_after(self._notifier.notify_replication)
 
         # Update the `stream_positions` table with newly updated stream
@@ -671,14 +671,50 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
 
         return self._return_factor * next_id
 
-    def _mark_id_as_finished(self, next_id: int) -> None:
-        """The ID has finished being processed so we should advance the
+    def get_next_mult_txn(self, txn: LoggingTransaction, n: int) -> List[int]:
+        """
+        Usage:
+
+            stream_ids = stream_id_gen.get_next_mult_txn(txn, 5)
+            # ... persist events ...
+        """
+
+        # If we have a list of instances that are allowed to write to this
+        # stream, make sure we're in it.
+        if self._writers and self._instance_name not in self._writers:
+            raise Exception("Tried to allocate stream ID on non-writer")
+
+        next_ids = self._load_next_mult_id_txn(txn, n)
+
+        txn.call_after(self._mark_ids_as_finished, next_ids)
+        txn.call_on_exception(self._mark_ids_as_finished, next_ids)
+        txn.call_after(self._notifier.notify_replication)
+
+        # Update the `stream_positions` table with newly updated stream
+        # ID (unless self._writers is not set in which case we don't
+        # bother, as nothing will read it).
+        #
+        # We only do this on the success path so that the persisted current
+        # position points to a persisted row with the correct instance name.
+        if self._writers:
+            txn.call_after(
+                run_as_background_process,
+                "MultiWriterIdGenerator._update_table",
+                self._db.runInteraction,
+                "MultiWriterIdGenerator._update_table",
+                self._update_stream_positions_table_txn,
+            )
+
+        return [self._return_factor * next_id for next_id in next_ids]
+
+    def _mark_ids_as_finished(self, next_ids: List[int]) -> None:
+        """These IDs have finished being processed so we should advance the
         current position if possible.
         """
 
         with self._lock:
-            self._unfinished_ids.discard(next_id)
-            self._finished_ids.add(next_id)
+            self._unfinished_ids.difference_update(next_ids)
+            self._finished_ids.update(next_ids)
 
             new_cur: Optional[int] = None
 
@@ -727,7 +763,10 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
                     curr, new_cur, self._max_position_of_local_instance
                 )
 
-            self._add_persisted_position(next_id)
+            # TODO: Can we call this for just the last position, or somehow batch
+            # up the calls to _add_persisted_position?
+            for next_id in next_ids:
+                self._add_persisted_position(next_id)
 
     def get_current_token(self) -> int:
         return self.get_persisted_upto_position()
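The new get_next_mult_txn mirrors get_next_txn but hands back a batch of stream IDs inside an existing transaction, marking them all as finished (or aborted) via the transaction callbacks. A sketch of how a caller might use it (the table and persistence function are hypothetical):

    def _persist_batch_txn(txn, id_gen, event_ids) -> None:
        # Allocate one stream ID per event up front, in a single call.
        stream_ids = id_gen.get_next_mult_txn(txn, len(event_ids))
        for event_id, stream_id in zip(event_ids, stream_ids):
            txn.execute(
                "INSERT INTO example_events (event_id, stream_ordering) VALUES (?, ?)",
                (event_id, stream_id),
            )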
@@ -933,8 +972,7 @@ class _MultiWriterCtxManager:
         exc: Optional[BaseException],
         tb: Optional[TracebackType],
     ) -> bool:
-        for i in self.stream_ids:
-            self.id_gen._mark_id_as_finished(i)
+        self.id_gen._mark_ids_as_finished(self.stream_ids)
 
         self.notifier.notify_replication()
 
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 9f3b8741c1..8d9df352b2 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -93,7 +93,7 @@ class Clock:
 
     _reactor: IReactorTime = attr.ib()
 
-    @defer.inlineCallbacks  # type: ignore[arg-type]  # Issue in Twisted's type annotations
+    @defer.inlineCallbacks
     def sleep(self, seconds: float) -> "Generator[Deferred[float], Any, Any]":
         d: defer.Deferred[float] = defer.Deferred()
         with context.PreserveLoggingContext():
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 0cbeb0c365..8a55e4e41d 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -345,6 +345,7 @@ async def yieldable_gather_results_delaying_cancellation(
 T1 = TypeVar("T1")
 T2 = TypeVar("T2")
 T3 = TypeVar("T3")
+T4 = TypeVar("T4")
 
 
 @overload
@@ -380,6 +381,19 @@ def gather_results(
     ...
 
 
+@overload
+def gather_results(
+    deferredList: Tuple[
+        "defer.Deferred[T1]",
+        "defer.Deferred[T2]",
+        "defer.Deferred[T3]",
+        "defer.Deferred[T4]",
+    ],
+    consumeErrors: bool = ...,
+) -> "defer.Deferred[Tuple[T1, T2, T3, T4]]":
+    ...
+
+
 def gather_results(  # type: ignore[misc]
     deferredList: Tuple["defer.Deferred[T1]", ...],
     consumeErrors: bool = False,
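The added overload keeps the element types when gathering four deferreds. A small usage sketch under that assumption:

    from twisted.internet import defer

    from synapse.util.async_helpers import gather_results

    async def fetch_four() -> None:
        d1: "defer.Deferred[int]" = defer.succeed(1)
        d2: "defer.Deferred[str]" = defer.succeed("two")
        d3: "defer.Deferred[bool]" = defer.succeed(True)
        d4: "defer.Deferred[float]" = defer.succeed(4.0)
        # mypy now infers Tuple[int, str, bool, float] for the awaited result.
        a, b, c, d = await gather_results((d1, d2, d3, d4), consumeErrors=True)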
diff --git a/synapse/util/check_dependencies.py b/synapse/util/check_dependencies.py
index f7cead9e12..6f008734a0 100644
--- a/synapse/util/check_dependencies.py
+++ b/synapse/util/check_dependencies.py
@@ -189,7 +189,8 @@ def check_requirements(extra: Optional[str] = None) -> None:
                 errors.append(_not_installed(requirement, extra))
         else:
             if dist.version is None:
-                # This shouldn't happen---it suggests a borked virtualenv. (See #12223)
+                # This shouldn't happen---it suggests a borked virtualenv. (See
+                # https://github.com/matrix-org/synapse/issues/12223)
                 # Try to give a vaguely helpful error message anyway.
                 # Type-ignore: the annotations don't reflect reality: see
                 #     https://github.com/python/typeshed/issues/7513
diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py
index a0efb96d3b..f4c0194af0 100644
--- a/synapse/util/iterutils.py
+++ b/synapse/util/iterutils.py
@@ -135,3 +135,54 @@ def sorted_topologically(
                 degree_map[edge] -= 1
                 if degree_map[edge] == 0:
                     heapq.heappush(zero_degree, edge)
+
+
+def sorted_topologically_batched(
+    nodes: Iterable[T],
+    graph: Mapping[T, Collection[T]],
+) -> Generator[Collection[T], None, None]:
+    r"""Walk the graph topologically, returning batches of nodes where all nodes
+    that references it have been previously returned.
+
+    For example, given the following graph:
+
+         A
+        / \
+       B   C
+        \ /
+         D
+
+    This function will return: `[[A], [B, C], [D]]`.
+
+    This function is useful for e.g. batch persisting events in an auth chain,
+    where we can only persist an event if all its auth events have already been
+    persisted.
+    """
+
+    degree_map = {node: 0 for node in nodes}
+    reverse_graph: Dict[T, Set[T]] = {}
+
+    for node, edges in graph.items():
+        if node not in degree_map:
+            continue
+
+        for edge in set(edges):
+            if edge in degree_map:
+                degree_map[node] += 1
+
+            reverse_graph.setdefault(edge, set()).add(node)
+        reverse_graph.setdefault(node, set())
+
+    zero_degree = [node for node, degree in degree_map.items() if degree == 0]
+
+    while zero_degree:
+        new_zero_degree = []
+        for node in zero_degree:
+            for edge in reverse_graph.get(node, []):
+                if edge in degree_map:
+                    degree_map[edge] -= 1
+                    if degree_map[edge] == 0:
+                        new_zero_degree.append(edge)
+
+        yield zero_degree
+        zero_degree = new_zero_degree
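Using the example graph from the docstring, where the mapping gives the nodes each node references:

    from synapse.util.iterutils import sorted_topologically_batched

    graph = {"B": ["A"], "C": ["A"], "D": ["B", "C"]}
    batches = list(sorted_topologically_batched(["A", "B", "C", "D"], graph))
    # Yields [["A"], ["B", "C"], ["D"]] (ordering within a batch is not
    # guaranteed): each batch only depends on nodes from earlier batches.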
diff --git a/synapse/util/task_scheduler.py b/synapse/util/task_scheduler.py
index caf13b3474..8c2df233d3 100644
--- a/synapse/util/task_scheduler.py
+++ b/synapse/util/task_scheduler.py
@@ -71,7 +71,7 @@ class TaskScheduler:
     # Time before a complete or failed task is deleted from the DB
     KEEP_TASKS_FOR_MS = 7 * 24 * 60 * 60 * 1000  # 1 week
     # Maximum number of tasks that can run at the same time
-    MAX_CONCURRENT_RUNNING_TASKS = 10
+    MAX_CONCURRENT_RUNNING_TASKS = 5
     # Time from the last task update after which we will log a warning
     LAST_UPDATE_BEFORE_WARNING_MS = 24 * 60 * 60 * 1000  # 24hrs
 
@@ -193,7 +193,7 @@ class TaskScheduler:
         result: Optional[JsonMapping] = None,
         error: Optional[str] = None,
     ) -> bool:
-        """Update some task associated values. This is exposed publically so it can
+        """Update some task associated values. This is exposed publicly so it can
         be used inside task functions, mainly to update the result and be able to
         resume a task at a specific step after a restart of synapse.