summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/constants.py4
-rw-r--r--synapse/config/key.py36
-rw-r--r--synapse/config/metrics.py6
-rw-r--r--synapse/config/server.py2
-rw-r--r--synapse/config/tls.py2
-rw-r--r--synapse/events/utils.py11
-rw-r--r--synapse/handlers/auth.py4
-rw-r--r--synapse/handlers/device.py10
-rw-r--r--synapse/handlers/e2e_keys.py12
-rw-r--r--synapse/handlers/message.py24
-rw-r--r--synapse/handlers/receipts.py6
-rw-r--r--synapse/handlers/room_member.py3
-rw-r--r--synapse/handlers/sync.py55
-rw-r--r--synapse/http/__init__.py6
-rw-r--r--synapse/http/additional_resource.py12
-rw-r--r--synapse/http/client.py15
-rw-r--r--synapse/http/matrixfederationclient.py3
-rw-r--r--synapse/http/server.py90
-rw-r--r--synapse/http/servlet.py50
-rw-r--r--synapse/http/site.py8
-rw-r--r--synapse/push/push_tools.py3
-rw-r--r--synapse/push/pusherpool.py5
-rw-r--r--synapse/replication/slave/storage/_base.py9
-rw-r--r--synapse/replication/slave/storage/client_ips.py9
-rw-r--r--synapse/replication/slave/storage/devices.py9
-rw-r--r--synapse/replication/slave/storage/events.py9
-rw-r--r--synapse/replication/slave/storage/filtering.py9
-rw-r--r--synapse/replication/slave/storage/groups.py9
-rw-r--r--synapse/rest/admin/__init__.py4
-rw-r--r--synapse/rest/admin/background_updates.py16
-rw-r--r--synapse/rest/admin/devices.py22
-rw-r--r--synapse/rest/admin/event_reports.py2
-rw-r--r--synapse/rest/admin/federation.py2
-rw-r--r--synapse/rest/admin/groups.py2
-rw-r--r--synapse/rest/admin/media.py60
-rw-r--r--synapse/rest/admin/registration_tokens.py3
-rw-r--r--synapse/rest/admin/rooms.py70
-rw-r--r--synapse/rest/admin/server_notice_servlet.py4
-rw-r--r--synapse/rest/admin/statistics.py22
-rw-r--r--synapse/rest/admin/username_available.py2
-rw-r--r--synapse/rest/admin/users.py51
-rw-r--r--synapse/rest/client/devices.py6
-rw-r--r--synapse/rest/client/notifications.py3
-rw-r--r--synapse/rest/client/read_marker.py6
-rw-r--r--synapse/rest/client/receipts.py4
-rw-r--r--synapse/rest/client/relations.py7
-rw-r--r--synapse/rest/client/room.py2
-rw-r--r--synapse/rest/client/sync.py3
-rw-r--r--synapse/rest/key/v2/local_key_resource.py4
-rw-r--r--synapse/rest/media/v1/oembed.py5
-rw-r--r--synapse/rest/media/v1/preview_html.py397
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py383
-rw-r--r--synapse/storage/_base.py13
-rw-r--r--synapse/storage/database.py111
-rw-r--r--synapse/storage/databases/main/__init__.py12
-rw-r--r--synapse/storage/databases/main/account_data.py93
-rw-r--r--synapse/storage/databases/main/appservice.py10
-rw-r--r--synapse/storage/databases/main/cache.py9
-rw-r--r--synapse/storage/databases/main/censor_events.py13
-rw-r--r--synapse/storage/databases/main/client_ips.py22
-rw-r--r--synapse/storage/databases/main/deviceinbox.py7
-rw-r--r--synapse/storage/databases/main/devices.py36
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py211
-rw-r--r--synapse/storage/databases/main/event_federation.py26
-rw-r--r--synapse/storage/databases/main/event_push_actions.py20
-rw-r--r--synapse/storage/databases/main/events.py135
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py8
-rw-r--r--synapse/storage/databases/main/group_server.py10
-rw-r--r--synapse/storage/databases/main/lock.py14
-rw-r--r--synapse/storage/databases/main/metrics.py27
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py27
-rw-r--r--synapse/storage/databases/main/presence.py11
-rw-r--r--synapse/storage/databases/main/push_rule.py9
-rw-r--r--synapse/storage/databases/main/receipts.py112
-rw-r--r--synapse/storage/databases/main/registration.py7
-rw-r--r--synapse/storage/databases/main/relations.py38
-rw-r--r--synapse/storage/databases/main/room.py29
-rw-r--r--synapse/storage/databases/main/roommember.py23
-rw-r--r--synapse/storage/databases/main/search.py20
-rw-r--r--synapse/storage/databases/main/state.py27
-rw-r--r--synapse/storage/databases/main/state_deltas.py4
-rw-r--r--synapse/storage/databases/main/stats.py11
-rw-r--r--synapse/storage/databases/main/stream.py8
-rw-r--r--synapse/storage/databases/main/tags.py22
-rw-r--r--synapse/storage/databases/main/transactions.py13
-rw-r--r--synapse/storage/databases/main/user_directory.py11
-rw-r--r--synapse/storage/schema/__init__.py5
-rw-r--r--synapse/storage/util/id_generators.py6
88 files changed, 1636 insertions, 1025 deletions
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index f7d29b4319..52c083a20b 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -253,5 +253,9 @@ class GuestAccess:
     FORBIDDEN: Final = "forbidden"
 
 
+class ReceiptTypes:
+    READ: Final = "m.read"
+
+
 class ReadReceiptEventFields:
     MSC2285_HIDDEN: Final = "org.matrix.msc2285.hidden"
diff --git a/synapse/config/key.py b/synapse/config/key.py
index 035ee2416b..ee83c6c06b 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -16,12 +16,14 @@
 import hashlib
 import logging
 import os
-from typing import Any, Dict
+from typing import Any, Dict, Iterator, List, Optional
 
 import attr
 import jsonschema
 from signedjson.key import (
     NACL_ED25519,
+    SigningKey,
+    VerifyKey,
     decode_signing_key_base64,
     decode_verify_key_bytes,
     generate_signing_key,
@@ -31,6 +33,7 @@ from signedjson.key import (
 )
 from unpaddedbase64 import decode_base64
 
+from synapse.types import JsonDict
 from synapse.util.stringutils import random_string, random_string_with_symbols
 
 from ._base import Config, ConfigError
@@ -81,14 +84,13 @@ To suppress this warning and continue using 'matrix.org', admins should set
 logger = logging.getLogger(__name__)
 
 
-@attr.s
+@attr.s(slots=True, auto_attribs=True)
 class TrustedKeyServer:
-    # string: name of the server.
-    server_name = attr.ib()
+    # name of the server.
+    server_name: str
 
-    # dict[str,VerifyKey]|None: map from key id to key object, or None to disable
-    # signature verification.
-    verify_keys = attr.ib(default=None)
+    # map from key id to key object, or None to disable signature verification.
+    verify_keys: Optional[Dict[str, VerifyKey]] = None
 
 
 class KeyConfig(Config):
@@ -279,15 +281,15 @@ class KeyConfig(Config):
             % locals()
         )
 
-    def read_signing_keys(self, signing_key_path, name):
+    def read_signing_keys(self, signing_key_path: str, name: str) -> List[SigningKey]:
         """Read the signing keys in the given path.
 
         Args:
-            signing_key_path (str)
-            name (str): Associated config key name
+            signing_key_path
+            name: Associated config key name
 
         Returns:
-            list[SigningKey]
+            The signing keys read from the given path.
         """
 
         signing_keys = self.read_file(signing_key_path, name)
@@ -296,7 +298,9 @@ class KeyConfig(Config):
         except Exception as e:
             raise ConfigError("Error reading %s: %s" % (name, str(e)))
 
-    def read_old_signing_keys(self, old_signing_keys):
+    def read_old_signing_keys(
+        self, old_signing_keys: Optional[JsonDict]
+    ) -> Dict[str, VerifyKey]:
         if old_signing_keys is None:
             return {}
         keys = {}
@@ -340,7 +344,7 @@ class KeyConfig(Config):
                     write_signing_keys(signing_key_file, (key,))
 
 
-def _perspectives_to_key_servers(config):
+def _perspectives_to_key_servers(config: JsonDict) -> Iterator[JsonDict]:
     """Convert old-style 'perspectives' configs into new-style 'trusted_key_servers'
 
     Returns an iterable of entries to add to trusted_key_servers.
@@ -402,7 +406,9 @@ TRUSTED_KEY_SERVERS_SCHEMA = {
 }
 
 
-def _parse_key_servers(key_servers, federation_verify_certificates):
+def _parse_key_servers(
+    key_servers: List[Any], federation_verify_certificates: bool
+) -> Iterator[TrustedKeyServer]:
     try:
         jsonschema.validate(key_servers, TRUSTED_KEY_SERVERS_SCHEMA)
     except jsonschema.ValidationError as e:
@@ -444,7 +450,7 @@ def _parse_key_servers(key_servers, federation_verify_certificates):
         yield result
 
 
-def _assert_keyserver_has_verify_keys(trusted_key_server):
+def _assert_keyserver_has_verify_keys(trusted_key_server: TrustedKeyServer) -> None:
     if not trusted_key_server.verify_keys:
         raise ConfigError(INSECURE_NOTARY_ERROR)
 
diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py
index 7ac82edb0e..1cc26e7578 100644
--- a/synapse/config/metrics.py
+++ b/synapse/config/metrics.py
@@ -22,10 +22,12 @@ from ._base import Config, ConfigError
 
 @attr.s
 class MetricsFlags:
-    known_servers = attr.ib(default=False, validator=attr.validators.instance_of(bool))
+    known_servers: bool = attr.ib(
+        default=False, validator=attr.validators.instance_of(bool)
+    )
 
     @classmethod
-    def all_off(cls):
+    def all_off(cls) -> "MetricsFlags":
         """
         Instantiate the flags with all options set to off.
         """
diff --git a/synapse/config/server.py b/synapse/config/server.py
index ba5b954263..1de2dea9b0 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -1257,7 +1257,7 @@ class ServerConfig(Config):
             help="Turn on the twisted telnet manhole service on the given port.",
         )
 
-    def read_gc_intervals(self, durations) -> Optional[Tuple[float, float, float]]:
+    def read_gc_intervals(self, durations: Any) -> Optional[Tuple[float, float, float]]:
         """Reads the three durations for the GC min interval option, returning seconds."""
         if durations is None:
             return None
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 4ca111618f..ffb316e4c0 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -132,7 +132,7 @@ class TlsConfig(Config):
         self.tls_certificate: Optional[crypto.X509] = None
         self.tls_private_key: Optional[crypto.PKey] = None
 
-    def read_certificate_from_disk(self):
+    def read_certificate_from_disk(self) -> None:
         """
         Read the certificates and private key from disk.
         """
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 84ef69df67..3da432c1df 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -454,23 +454,26 @@ class EventClientSerializer:
                 return
 
         event_id = event.event_id
+        room_id = event.room_id
 
         # The bundled aggregations to include.
         aggregations = {}
 
-        annotations = await self.store.get_aggregation_groups_for_event(event_id)
+        annotations = await self.store.get_aggregation_groups_for_event(
+            event_id, room_id
+        )
         if annotations.chunk:
             aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()
 
         references = await self.store.get_relations_for_event(
-            event_id, RelationTypes.REFERENCE, direction="f"
+            event_id, room_id, RelationTypes.REFERENCE, direction="f"
         )
         if references.chunk:
             aggregations[RelationTypes.REFERENCE] = references.to_dict()
 
         edit = None
         if event.type == EventTypes.Message:
-            edit = await self.store.get_applicable_edit(event_id)
+            edit = await self.store.get_applicable_edit(event_id, room_id)
 
         if edit:
             # If there is an edit replace the content, preserving existing
@@ -503,7 +506,7 @@ class EventClientSerializer:
             (
                 thread_count,
                 latest_thread_event,
-            ) = await self.store.get_thread_summary(event_id)
+            ) = await self.store.get_thread_summary(event_id, room_id)
             if latest_thread_event:
                 aggregations[RelationTypes.THREAD] = {
                     # Don't bundle aggregations as this could recurse forever.
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 61607cf2ba..84724b207c 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -997,9 +997,7 @@ class AuthHandler:
         # really don't want is active access_tokens without a record of the
         # device, so we double-check it here.
         if device_id is not None:
-            try:
-                await self.store.get_device(user_id, device_id)
-            except StoreError:
+            if await self.store.get_device(user_id, device_id) is None:
                 await self.store.delete_access_token(access_token)
                 raise StoreError(400, "Login raced against device deletion")
 
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 82ee11e921..7665425232 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -106,10 +106,10 @@ class DeviceWorkerHandler:
         Raises:
             errors.NotFoundError: if the device was not found
         """
-        try:
-            device = await self.store.get_device(user_id, device_id)
-        except errors.StoreError:
-            raise errors.NotFoundError
+        device = await self.store.get_device(user_id, device_id)
+        if device is None:
+            raise errors.NotFoundError()
+
         ips = await self.store.get_last_client_ip_by_device(user_id, device_id)
         _update_device_from_client_ips(device, ips)
 
@@ -602,6 +602,8 @@ class DeviceHandler(DeviceWorkerHandler):
             access_token, device_id
         )
         old_device = await self.store.get_device(user_id, old_device_id)
+        if old_device is None:
+            raise errors.NotFoundError()
         await self.store.update_device(user_id, device_id, old_device["display_name"])
         # can't call self.delete_device because that will clobber the
         # access token so call the storage layer directly
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 60c11e3d21..14360b4e40 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -65,8 +65,12 @@ class E2eKeysHandler:
         else:
             # Only register this edu handler on master as it requires writing
             # device updates to the db
-            #
-            # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+            federation_registry.register_edu_handler(
+                "m.signing_key_update",
+                self._edu_updater.incoming_signing_key_update,
+            )
+            # also handle the unstable version
+            # FIXME: remove this when enough servers have upgraded
             federation_registry.register_edu_handler(
                 "org.matrix.signing_key_update",
                 self._edu_updater.incoming_signing_key_update,
@@ -576,7 +580,9 @@ class E2eKeysHandler:
             log_kv(
                 {"message": "Did not update one_time_keys", "reason": "no keys given"}
             )
-        fallback_keys = keys.get("org.matrix.msc2732.fallback_keys", None)
+        fallback_keys = keys.get("fallback_keys") or keys.get(
+            "org.matrix.msc2732.fallback_keys"
+        )
         if fallback_keys and isinstance(fallback_keys, dict):
             log_kv(
                 {
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 87f671708c..38409fef38 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -496,6 +496,7 @@ class EventCreationHandler:
         require_consent: bool = True,
         outlier: bool = False,
         historical: bool = False,
+        allow_no_prev_events: bool = False,
         depth: Optional[int] = None,
     ) -> Tuple[EventBase, EventContext]:
         """
@@ -607,6 +608,7 @@ class EventCreationHandler:
             prev_event_ids=prev_event_ids,
             auth_event_ids=auth_event_ids,
             depth=depth,
+            allow_no_prev_events=allow_no_prev_events,
         )
 
         # In an ideal world we wouldn't need the second part of this condition. However,
@@ -882,6 +884,7 @@ class EventCreationHandler:
         prev_event_ids: Optional[List[str]] = None,
         auth_event_ids: Optional[List[str]] = None,
         depth: Optional[int] = None,
+        allow_no_prev_events: bool = False,
     ) -> Tuple[EventBase, EventContext]:
         """Create a new event for a local client
 
@@ -912,6 +915,7 @@ class EventCreationHandler:
         full_state_ids_at_event = None
         if auth_event_ids is not None:
             # If auth events are provided, prev events must be also.
+            # prev_event_ids could be an empty array though.
             assert prev_event_ids is not None
 
             # Copy the full auth state before it stripped down
@@ -943,14 +947,22 @@ class EventCreationHandler:
         else:
             prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id)
 
-        # we now ought to have some prev_events (unless it's a create event).
-        #
-        # do a quick sanity check here, rather than waiting until we've created the
+        # Do a quick sanity check here, rather than waiting until we've created the
         # event and then try to auth it (which fails with a somewhat confusing "No
         # create event in auth events")
-        assert (
-            builder.type == EventTypes.Create or len(prev_event_ids) > 0
-        ), "Attempting to create an event with no prev_events"
+        if allow_no_prev_events:
+            # We allow events with no `prev_events` but it better have some `auth_events`
+            assert (
+                builder.type == EventTypes.Create
+                # Allow an event to have empty list of prev_event_ids
+                # only if it has auth_event_ids.
+                or auth_event_ids
+            ), "Attempting to create a non-m.room.create event with no prev_events or auth_event_ids"
+        else:
+            # we now ought to have some prev_events (unless it's a create event).
+            assert (
+                builder.type == EventTypes.Create or prev_event_ids
+            ), "Attempting to create a non-m.room.create event with no prev_events"
 
         event = await builder.build(
             prev_event_ids=prev_event_ids,
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 4911a11535..5cb1ff749d 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -14,7 +14,7 @@
 import logging
 from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
 
-from synapse.api.constants import ReadReceiptEventFields
+from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
 from synapse.appservice import ApplicationService
 from synapse.streams import EventSource
 from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id
@@ -178,7 +178,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
 
             for event_id in content.keys():
                 event_content = content.get(event_id, {})
-                m_read = event_content.get("m.read", {})
+                m_read = event_content.get(ReceiptTypes.READ, {})
 
                 # If m_read is missing copy over the original event_content as there is nothing to process here
                 if not m_read:
@@ -206,7 +206,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
 
                 # Set new users unless empty
                 if len(new_users.keys()) > 0:
-                    new_event["content"][event_id] = {"m.read": new_users}
+                    new_event["content"][event_id] = {ReceiptTypes.READ: new_users}
 
             # Append new_event to visible_events unless empty
             if len(new_event["content"].keys()) > 0:
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index a6dbff637f..447e3ce571 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -658,7 +658,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             if block_invite:
                 raise SynapseError(403, "Invites have been disabled on this server")
 
-        if prev_event_ids:
+        # An empty prev_events list is allowed as long as the auth_event_ids are present
+        if prev_event_ids is not None:
             return await self._local_membership_update(
                 requester=requester,
                 target=target,
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index f3039c3c3f..bcd10cbb30 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -28,7 +28,7 @@ from typing import (
 import attr
 from prometheus_client import Counter
 
-from synapse.api.constants import AccountDataTypes, EventTypes, Membership
+from synapse.api.constants import AccountDataTypes, EventTypes, Membership, ReceiptTypes
 from synapse.api.filtering import FilterCollection
 from synapse.api.presence import UserPresenceState
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
@@ -1046,7 +1046,7 @@ class SyncHandler:
             last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
                 user_id=sync_config.user.to_string(),
                 room_id=room_id,
-                receipt_type="m.read",
+                receipt_type=ReceiptTypes.READ,
             )
 
             notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
@@ -1662,20 +1662,20 @@ class SyncHandler:
     ) -> _RoomChanges:
         """Determine the changes in rooms to report to the user.
 
-        Ideally, we want to report all events whose stream ordering `s` lies in the
-        range `since_token < s <= now_token`, where the two tokens are read from the
-        sync_result_builder.
+        This function is a first pass at generating the rooms part of the sync response.
+        It determines which rooms have changed during the sync period, and categorises
+        them into four buckets: "knock", "invite", "join" and "leave".
 
-        If there are too many events in that range to report, things get complicated.
-        In this situation we return a truncated list of the most recent events, and
-        indicate in the response that there is a "gap" of omitted events. Additionally:
+        1. Finds all membership changes for the user in the sync period (from
+           `since_token` up to `now_token`).
+        2. Uses those to place the room in one of the four categories above.
+        3. Builds a `_RoomChanges` struct to record this, and return that struct.
 
-        - we include a "state_delta", to describe the changes in state over the gap,
-        - we include all membership events applying to the user making the request,
-          even those in the gap.
-
-        See the spec for the rationale:
-            https://spec.matrix.org/v1.1/client-server-api/#syncing
+        For rooms classified as "knock", "invite" or "leave", we just need to report
+        a single membership event in the eventual /sync response. For "join" we need
+        to fetch additional non-membership events, e.g. messages in the room. That is
+        more complicated, so instead we report an intermediary `RoomSyncResultBuilder`
+        struct, and leave the additional work to `_generate_room_entry`.
 
         The sync_result_builder is not modified by this function.
         """
@@ -1686,16 +1686,6 @@ class SyncHandler:
 
         assert since_token
 
-        # The spec
-        #     https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3sync
-        # notes that membership events need special consideration:
-        #
-        # > When a sync is limited, the server MUST return membership events for events
-        # > in the gap (between since and the start of the returned timeline), regardless
-        # > as to whether or not they are redundant.
-        #
-        # We fetch such events here, but we only seem to use them for categorising rooms
-        # as newly joined, newly left, invited or knocked.
         # TODO: we've already called this function and ran this query in
         #       _have_rooms_changed. We could keep the results in memory to avoid a
         #       second query, at the cost of more complicated source code.
@@ -2009,6 +1999,23 @@ class SyncHandler:
         """Populates the `joined` and `archived` section of `sync_result_builder`
         based on the `room_builder`.
 
+        Ideally, we want to report all events whose stream ordering `s` lies in the
+        range `since_token < s <= now_token`, where the two tokens are read from the
+        sync_result_builder.
+
+        If there are too many events in that range to report, things get complicated.
+        In this situation we return a truncated list of the most recent events, and
+        indicate in the response that there is a "gap" of omitted events. Lots of this
+        is handled in `_load_filtered_recents`, but some of is handled in this method.
+
+        Additionally:
+        - we include a "state_delta", to describe the changes in state over the gap,
+        - we include all membership events applying to the user making the request,
+          even those in the gap.
+
+        See the spec for the rationale:
+            https://spec.matrix.org/v1.1/client-server-api/#syncing
+
         Args:
             sync_result_builder
             ignored_users: Set of users ignored by user.
diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index 578fc48ef4..efecb089c1 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -25,7 +25,7 @@ from synapse.api.errors import SynapseError
 class RequestTimedOutError(SynapseError):
     """Exception representing timeout of an outbound request"""
 
-    def __init__(self, msg):
+    def __init__(self, msg: str):
         super().__init__(504, msg)
 
 
@@ -33,7 +33,7 @@ ACCESS_TOKEN_RE = re.compile(r"(\?.*access(_|%5[Ff])token=)[^&]*(.*)$")
 CLIENT_SECRET_RE = re.compile(r"(\?.*client(_|%5[Ff])secret=)[^&]*(.*)$")
 
 
-def redact_uri(uri):
+def redact_uri(uri: str) -> str:
     """Strips sensitive information from the uri replaces with <redacted>"""
     uri = ACCESS_TOKEN_RE.sub(r"\1<redacted>\3", uri)
     return CLIENT_SECRET_RE.sub(r"\1<redacted>\3", uri)
@@ -46,7 +46,7 @@ class QuieterFileBodyProducer(FileBodyProducer):
     https://twistedmatrix.com/trac/ticket/6528
     """
 
-    def stopProducing(self):
+    def stopProducing(self) -> None:
         try:
             FileBodyProducer.stopProducing(self)
         except task.TaskStopped:
diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py
index 9a2684aca4..6a9f6635d2 100644
--- a/synapse/http/additional_resource.py
+++ b/synapse/http/additional_resource.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple
 
 from twisted.web.server import Request
 
@@ -32,7 +32,11 @@ class AdditionalResource(DirectServeJsonResource):
     and exception handling.
     """
 
-    def __init__(self, hs: "HomeServer", handler):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        handler: Callable[[Request], Awaitable[Optional[Tuple[int, Any]]]],
+    ):
         """Initialise AdditionalResource
 
         The ``handler`` should return a deferred which completes when it has
@@ -47,7 +51,7 @@ class AdditionalResource(DirectServeJsonResource):
         super().__init__()
         self._handler = handler
 
-    def _async_render(self, request: Request):
+    async def _async_render(self, request: Request) -> Optional[Tuple[int, Any]]:
         # Cheekily pass the result straight through, so we don't need to worry
         # if its an awaitable or not.
-        return self._handler(request)
+        return await self._handler(request)
diff --git a/synapse/http/client.py b/synapse/http/client.py
index b5a2d333a6..fbbeceabeb 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 import logging
 import urllib.parse
+from http import HTTPStatus
 from io import BytesIO
 from typing import (
     TYPE_CHECKING,
@@ -280,7 +281,9 @@ class BlacklistingAgentWrapper(Agent):
                 ip_address, self._ip_whitelist, self._ip_blacklist
             ):
                 logger.info("Blocking access to %s due to blacklist" % (ip_address,))
-                e = SynapseError(403, "IP address blocked by IP blacklist entry")
+                e = SynapseError(
+                    HTTPStatus.FORBIDDEN, "IP address blocked by IP blacklist entry"
+                )
                 return defer.fail(Failure(e))
 
         return self._agent.request(
@@ -719,7 +722,9 @@ class SimpleHttpClient:
 
         if response.code > 299:
             logger.warning("Got %d when downloading %s" % (response.code, url))
-            raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN)
+            raise SynapseError(
+                HTTPStatus.BAD_GATEWAY, "Got error %d" % (response.code,), Codes.UNKNOWN
+            )
 
         # TODO: if our Content-Type is HTML or something, just read the first
         # N bytes into RAM rather than saving it all to disk only to read it
@@ -731,12 +736,14 @@ class SimpleHttpClient:
             )
         except BodyExceededMaxSize:
             raise SynapseError(
-                502,
+                HTTPStatus.BAD_GATEWAY,
                 "Requested file is too large > %r bytes" % (max_size,),
                 Codes.TOO_LARGE,
             )
         except Exception as e:
-            raise SynapseError(502, ("Failed to download remote body: %s" % e)) from e
+            raise SynapseError(
+                HTTPStatus.BAD_GATEWAY, ("Failed to download remote body: %s" % e)
+            ) from e
 
         return (
             length,
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 203d723d41..deedde0b5b 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -19,6 +19,7 @@ import random
 import sys
 import typing
 import urllib.parse
+from http import HTTPStatus
 from io import BytesIO, StringIO
 from typing import (
     TYPE_CHECKING,
@@ -1154,7 +1155,7 @@ class MatrixFederationHttpClient:
                 request.destination,
                 msg,
             )
-            raise SynapseError(502, msg, Codes.TOO_LARGE)
+            raise SynapseError(HTTPStatus.BAD_GATEWAY, msg, Codes.TOO_LARGE)
         except defer.TimeoutError as e:
             logger.warning(
                 "{%s} [%s] Timed out reading response - %s %s",
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 91badb0b0a..4fd5660a08 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -30,6 +30,7 @@ from typing import (
     Iterable,
     Iterator,
     List,
+    NoReturn,
     Optional,
     Pattern,
     Tuple,
@@ -170,7 +171,9 @@ def return_html_error(
     respond_with_html(request, code, body)
 
 
-def wrap_async_request_handler(h):
+def wrap_async_request_handler(
+    h: Callable[["_AsyncResource", SynapseRequest], Awaitable[None]]
+) -> Callable[["_AsyncResource", SynapseRequest], "defer.Deferred[None]"]:
     """Wraps an async request handler so that it calls request.processing.
 
     This helps ensure that work done by the request handler after the request is completed
@@ -183,7 +186,9 @@ def wrap_async_request_handler(h):
     logged until the deferred completes.
     """
 
-    async def wrapped_async_request_handler(self, request):
+    async def wrapped_async_request_handler(
+        self: "_AsyncResource", request: SynapseRequest
+    ) -> None:
         with request.processing():
             await h(self, request)
 
@@ -240,18 +245,18 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
             context from the request the servlet is handling.
     """
 
-    def __init__(self, extract_context=False):
+    def __init__(self, extract_context: bool = False):
         super().__init__()
 
         self._extract_context = extract_context
 
-    def render(self, request):
+    def render(self, request: SynapseRequest) -> int:
         """This gets called by twisted every time someone sends us a request."""
         defer.ensureDeferred(self._async_render_wrapper(request))
         return NOT_DONE_YET
 
     @wrap_async_request_handler
-    async def _async_render_wrapper(self, request: SynapseRequest):
+    async def _async_render_wrapper(self, request: SynapseRequest) -> None:
         """This is a wrapper that delegates to `_async_render` and handles
         exceptions, return values, metrics, etc.
         """
@@ -271,7 +276,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
             f = failure.Failure()
             self._send_error_response(f, request)
 
-    async def _async_render(self, request: Request):
+    async def _async_render(self, request: SynapseRequest) -> Optional[Tuple[int, Any]]:
         """Delegates to `_async_render_<METHOD>` methods, or returns a 400 if
         no appropriate method exists. Can be overridden in sub classes for
         different routing.
@@ -318,7 +323,7 @@ class DirectServeJsonResource(_AsyncResource):
     formatting responses and errors as JSON.
     """
 
-    def __init__(self, canonical_json=False, extract_context=False):
+    def __init__(self, canonical_json: bool = False, extract_context: bool = False):
         super().__init__(extract_context)
         self.canonical_json = canonical_json
 
@@ -327,7 +332,7 @@ class DirectServeJsonResource(_AsyncResource):
         request: SynapseRequest,
         code: int,
         response_object: Any,
-    ):
+    ) -> None:
         """Implements _AsyncResource._send_response"""
         # TODO: Only enable CORS for the requests that need it.
         respond_with_json(
@@ -368,34 +373,45 @@ class JsonResource(DirectServeJsonResource):
 
     isLeaf = True
 
-    def __init__(self, hs: "HomeServer", canonical_json=True, extract_context=False):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        canonical_json: bool = True,
+        extract_context: bool = False,
+    ):
         super().__init__(canonical_json, extract_context)
         self.clock = hs.get_clock()
         self.path_regexs: Dict[bytes, List[_PathEntry]] = {}
         self.hs = hs
 
-    def register_paths(self, method, path_patterns, callback, servlet_classname):
+    def register_paths(
+        self,
+        method: str,
+        path_patterns: Iterable[Pattern],
+        callback: ServletCallback,
+        servlet_classname: str,
+    ) -> None:
         """
         Registers a request handler against a regular expression. Later request URLs are
         checked against these regular expressions in order to identify an appropriate
         handler for that request.
 
         Args:
-            method (str): GET, POST etc
+            method: GET, POST etc
 
-            path_patterns (Iterable[str]): A list of regular expressions to which
-                the request URLs are compared.
+            path_patterns: A list of regular expressions to which the request
+                URLs are compared.
 
-            callback (function): The handler for the request. Usually a Servlet
+            callback: The handler for the request. Usually a Servlet
 
-            servlet_classname (str): The name of the handler to be used in prometheus
+            servlet_classname: The name of the handler to be used in prometheus
                 and opentracing logs.
         """
-        method = method.encode("utf-8")  # method is bytes on py3
+        method_bytes = method.encode("utf-8")
 
         for path_pattern in path_patterns:
             logger.debug("Registering for %s %s", method, path_pattern.pattern)
-            self.path_regexs.setdefault(method, []).append(
+            self.path_regexs.setdefault(method_bytes, []).append(
                 _PathEntry(path_pattern, callback, servlet_classname)
             )
 
@@ -427,7 +443,7 @@ class JsonResource(DirectServeJsonResource):
         # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
         return _unrecognised_request_handler, "unrecognised_request_handler", {}
 
-    async def _async_render(self, request):
+    async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]:
         callback, servlet_classname, group_dict = self._get_handler_for_request(request)
 
         # Make sure we have an appropriate name for this handler in prometheus
@@ -468,7 +484,7 @@ class DirectServeHtmlResource(_AsyncResource):
         request: SynapseRequest,
         code: int,
         response_object: Any,
-    ):
+    ) -> None:
         """Implements _AsyncResource._send_response"""
         # We expect to get bytes for us to write
         assert isinstance(response_object, bytes)
@@ -492,12 +508,12 @@ class StaticResource(File):
     Differs from the File resource by adding clickjacking protection.
     """
 
-    def render_GET(self, request: Request):
+    def render_GET(self, request: Request) -> bytes:
         set_clickjacking_protection_headers(request)
         return super().render_GET(request)
 
 
-def _unrecognised_request_handler(request):
+def _unrecognised_request_handler(request: Request) -> NoReturn:
     """Request handler for unrecognised requests
 
     This is a request handler suitable for return from
@@ -505,7 +521,7 @@ def _unrecognised_request_handler(request):
     UnrecognizedRequestError.
 
     Args:
-        request (twisted.web.http.Request):
+        request: Unused, but passed in to match the signature of ServletCallback.
     """
     raise UnrecognizedRequestError()
 
@@ -513,14 +529,14 @@ def _unrecognised_request_handler(request):
 class RootRedirect(resource.Resource):
     """Redirects the root '/' path to another path."""
 
-    def __init__(self, path):
+    def __init__(self, path: str):
         resource.Resource.__init__(self)
         self.url = path
 
-    def render_GET(self, request):
+    def render_GET(self, request: Request) -> bytes:
         return redirectTo(self.url.encode("ascii"), request)
 
-    def getChild(self, name, request):
+    def getChild(self, name: str, request: Request) -> resource.Resource:
         if len(name) == 0:
             return self  # select ourselves as the child to render
         return resource.Resource.getChild(self, name, request)
@@ -529,7 +545,7 @@ class RootRedirect(resource.Resource):
 class OptionsResource(resource.Resource):
     """Responds to OPTION requests for itself and all children."""
 
-    def render_OPTIONS(self, request):
+    def render_OPTIONS(self, request: Request) -> bytes:
         request.setResponseCode(204)
         request.setHeader(b"Content-Length", b"0")
 
@@ -537,7 +553,7 @@ class OptionsResource(resource.Resource):
 
         return b""
 
-    def getChildWithDefault(self, path, request):
+    def getChildWithDefault(self, path: str, request: Request) -> resource.Resource:
         if request.method == b"OPTIONS":
             return self  # select ourselves as the child to render
         return resource.Resource.getChildWithDefault(self, path, request)
@@ -649,7 +665,7 @@ def respond_with_json(
     json_object: Any,
     send_cors: bool = False,
     canonical_json: bool = True,
-):
+) -> Optional[int]:
     """Sends encoded JSON in response to the given request.
 
     Args:
@@ -696,7 +712,7 @@ def respond_with_json_bytes(
     code: int,
     json_bytes: bytes,
     send_cors: bool = False,
-):
+) -> Optional[int]:
     """Sends encoded JSON in response to the given request.
 
     Args:
@@ -713,7 +729,7 @@ def respond_with_json_bytes(
         logger.warning(
             "Not sending response to request %s, already disconnected.", request
         )
-        return
+        return None
 
     request.setResponseCode(code)
     request.setHeader(b"Content-Type", b"application/json")
@@ -731,7 +747,7 @@ async def _async_write_json_to_request_in_thread(
     request: SynapseRequest,
     json_encoder: Callable[[Any], bytes],
     json_object: Any,
-):
+) -> None:
     """Encodes the given JSON object on a thread and then writes it to the
     request.
 
@@ -773,7 +789,7 @@ def _write_bytes_to_request(request: Request, bytes_to_write: bytes) -> None:
     _ByteProducer(request, bytes_generator)
 
 
-def set_cors_headers(request: Request):
+def set_cors_headers(request: Request) -> None:
     """Set the CORS headers so that javascript running in a web browsers can
     use this API
 
@@ -790,14 +806,14 @@ def set_cors_headers(request: Request):
     )
 
 
-def respond_with_html(request: Request, code: int, html: str):
+def respond_with_html(request: Request, code: int, html: str) -> None:
     """
     Wraps `respond_with_html_bytes` by first encoding HTML from a str to UTF-8 bytes.
     """
     respond_with_html_bytes(request, code, html.encode("utf-8"))
 
 
-def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes):
+def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes) -> None:
     """
     Sends HTML (encoded as UTF-8 bytes) as the response to the given request.
 
@@ -815,7 +831,7 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes):
         logger.warning(
             "Not sending response to request %s, already disconnected.", request
         )
-        return
+        return None
 
     request.setResponseCode(code)
     request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
@@ -828,7 +844,7 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes):
     finish_request(request)
 
 
-def set_clickjacking_protection_headers(request: Request):
+def set_clickjacking_protection_headers(request: Request) -> None:
     """
     Set headers to guard against clickjacking of embedded content.
 
@@ -850,7 +866,7 @@ def respond_with_redirect(request: Request, url: bytes) -> None:
     finish_request(request)
 
 
-def finish_request(request: Request):
+def finish_request(request: Request) -> None:
     """Finish writing the response to the request.
 
     Twisted throws a RuntimeException if the connection closed before the
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 6dd9b9ad03..4ff840ca0e 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -14,6 +14,7 @@
 
 """ This module contains base REST classes for constructing REST servlets. """
 import logging
+from http import HTTPStatus
 from typing import (
     TYPE_CHECKING,
     Iterable,
@@ -30,6 +31,7 @@ from typing_extensions import Literal
 from twisted.web.server import Request
 
 from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import HttpServer
 from synapse.types import JsonDict, RoomAlias, RoomID
 from synapse.util import json_decoder
 
@@ -137,11 +139,15 @@ def parse_integer_from_args(
             return int(args[name_bytes][0])
         except Exception:
             message = "Query parameter %r must be an integer" % (name,)
-            raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM
+            )
     else:
         if required:
             message = "Missing integer query parameter %r" % (name,)
-            raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM
+            )
         else:
             return default
 
@@ -246,11 +252,15 @@ def parse_boolean_from_args(
             message = (
                 "Boolean query parameter %r must be one of ['true', 'false']"
             ) % (name,)
-            raise SynapseError(400, message)
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM
+            )
     else:
         if required:
             message = "Missing boolean query parameter %r" % (name,)
-            raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM
+            )
         else:
             return default
 
@@ -313,7 +323,7 @@ def parse_bytes_from_args(
         return args[name_bytes][0]
     elif required:
         message = "Missing string query parameter %s" % (name,)
-        raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
+        raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM)
 
     return default
 
@@ -407,14 +417,16 @@ def _parse_string_value(
     try:
         value_str = value.decode(encoding)
     except ValueError:
-        raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding))
+        raise SynapseError(
+            HTTPStatus.BAD_REQUEST, "Query parameter %r must be %s" % (name, encoding)
+        )
 
     if allowed_values is not None and value_str not in allowed_values:
         message = "Query parameter %r must be one of [%s]" % (
             name,
             ", ".join(repr(v) for v in allowed_values),
         )
-        raise SynapseError(400, message)
+        raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM)
     else:
         return value_str
 
@@ -510,7 +522,9 @@ def parse_strings_from_args(
     else:
         if required:
             message = "Missing string query parameter %r" % (name,)
-            raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM
+            )
 
         return default
 
@@ -638,7 +652,7 @@ def parse_json_value_from_request(
     try:
         content_bytes = request.content.read()  # type: ignore
     except Exception:
-        raise SynapseError(400, "Error reading JSON content.")
+        raise SynapseError(HTTPStatus.BAD_REQUEST, "Error reading JSON content.")
 
     if not content_bytes and allow_empty_body:
         return None
@@ -647,7 +661,9 @@ def parse_json_value_from_request(
         content = json_decoder.decode(content_bytes.decode("utf-8"))
     except Exception as e:
         logger.warning("Unable to parse JSON: %s (%s)", e, content_bytes)
-        raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
+        raise SynapseError(
+            HTTPStatus.BAD_REQUEST, "Content not JSON.", errcode=Codes.NOT_JSON
+        )
 
     return content
 
@@ -673,7 +689,7 @@ def parse_json_object_from_request(
 
     if not isinstance(content, dict):
         message = "Content must be a JSON object."
-        raise SynapseError(400, message, errcode=Codes.BAD_JSON)
+        raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.BAD_JSON)
 
     return content
 
@@ -685,7 +701,9 @@ def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None:
             absent.append(k)
 
     if len(absent) > 0:
-        raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+        raise SynapseError(
+            HTTPStatus.BAD_REQUEST, "Missing params: %r" % absent, Codes.MISSING_PARAM
+        )
 
 
 class RestServlet:
@@ -709,7 +727,7 @@ class RestServlet:
     into the appropriate HTTP response.
     """
 
-    def register(self, http_server):
+    def register(self, http_server: HttpServer) -> None:
         """Register this servlet with the given HTTP server."""
         patterns = getattr(self, "PATTERNS", None)
         if patterns:
@@ -758,10 +776,12 @@ class ResolveRoomIdMixin:
             resolved_room_id = room_id.to_string()
         else:
             raise SynapseError(
-                400, "%s was not legal room ID or room alias" % (room_identifier,)
+                HTTPStatus.BAD_REQUEST,
+                "%s was not legal room ID or room alias" % (room_identifier,),
             )
         if not resolved_room_id:
             raise SynapseError(
-                400, "Unknown room ID or room alias %s" % room_identifier
+                HTTPStatus.BAD_REQUEST,
+                "Unknown room ID or room alias %s" % room_identifier,
             )
         return resolved_room_id, remote_room_hosts
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 755ad56637..9f68d7e191 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -14,7 +14,7 @@
 import contextlib
 import logging
 import time
-from typing import Generator, Optional, Tuple, Union
+from typing import Any, Generator, Optional, Tuple, Union
 
 import attr
 from zope.interface import implementer
@@ -66,9 +66,9 @@ class SynapseRequest(Request):
         self,
         channel: HTTPChannel,
         site: "SynapseSite",
-        *args,
+        *args: Any,
         max_request_body_size: int = 1024,
-        **kw,
+        **kw: Any,
     ):
         super().__init__(channel, *args, **kw)
         self._max_request_body_size = max_request_body_size
@@ -557,7 +557,7 @@ class SynapseSite(Site):
         proxied = config.http_options.x_forwarded
         request_class = XForwardedForRequest if proxied else SynapseRequest
 
-        def request_factory(channel, queued: bool) -> Request:
+        def request_factory(channel: HTTPChannel, queued: bool) -> Request:
             return request_class(
                 channel,
                 self,
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index 9c85200c0f..da641aca47 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 from typing import Dict
 
+from synapse.api.constants import ReceiptTypes
 from synapse.events import EventBase
 from synapse.push.presentable_names import calculate_room_name, name_from_member_event
 from synapse.storage import Storage
@@ -23,7 +24,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
     invites = await store.get_invited_rooms_for_local_user(user_id)
     joins = await store.get_rooms_for_user(user_id)
 
-    my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read")
+    my_receipts_by_room = await store.get_receipts_for_user(user_id, ReceiptTypes.READ)
 
     badge = len(invites)
 
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 26735447a6..7912311d24 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -27,6 +27,7 @@ from synapse.push.pusher import PusherFactory
 from synapse.replication.http.push import ReplicationRemovePusherRestServlet
 from synapse.types import JsonDict, RoomStreamToken
 from synapse.util.async_helpers import concurrently_execute
+from synapse.util.threepids import canonicalise_email
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -113,7 +114,9 @@ class PusherPool:
         """
 
         if kind == "email":
-            email_owner = await self.store.get_user_id_by_threepid("email", pushkey)
+            email_owner = await self.store.get_user_id_by_threepid(
+                "email", canonicalise_email(pushkey)
+            )
             if email_owner != user_id:
                 raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
 
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 7ecb446e7c..7644146dba 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -15,7 +15,7 @@
 import logging
 from typing import TYPE_CHECKING, Optional
 
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
@@ -27,7 +27,12 @@ logger = logging.getLogger(__name__)
 
 
 class BaseSlavedStore(CacheInvalidationWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
         if isinstance(self.database_engine, PostgresEngine):
             self._cache_id_gen: Optional[
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index 61cd7e5228..bc888ce1a8 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -14,7 +14,7 @@
 
 from typing import TYPE_CHECKING
 
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
 from synapse.util.caches.lrucache import LruCache
 
@@ -25,7 +25,12 @@ if TYPE_CHECKING:
 
 
 class SlavedClientIpStore(BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.client_ip_last_seen: LruCache[tuple, int] = LruCache(
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 0a58296089..a2aff75b70 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.devices import DeviceWorkerStore
 from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -27,7 +27,12 @@ if TYPE_CHECKING:
 
 
 class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.hs = hs
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 63ed50caa5..50e7379e83 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -15,7 +15,7 @@
 import logging
 from typing import TYPE_CHECKING
 
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
 from synapse.storage.databases.main.event_push_actions import (
     EventPushActionsWorkerStore,
@@ -58,7 +58,12 @@ class SlavedEventStore(
     RelationsWorkerStore,
     BaseSlavedStore,
 ):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         events_max = self._stream_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py
index 90284c202d..4d185e2b56 100644
--- a/synapse/replication/slave/storage/filtering.py
+++ b/synapse/replication/slave/storage/filtering.py
@@ -14,7 +14,7 @@
 
 from typing import TYPE_CHECKING
 
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.filtering import FilteringStore
 
 from ._base import BaseSlavedStore
@@ -24,7 +24,12 @@ if TYPE_CHECKING:
 
 
 class SlavedFilteringStore(BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
     # Filters are immutable so this cache doesn't need to be expired
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 497e16c69e..9d90e26375 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import GroupServerStream
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.group_server import GroupServerWorkerStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
@@ -26,7 +26,12 @@ if TYPE_CHECKING:
 
 
 class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.hs = hs
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index c499afd4be..701c609c12 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -108,7 +108,7 @@ class VersionServlet(RestServlet):
 
 class PurgeHistoryRestServlet(RestServlet):
     PATTERNS = admin_patterns(
-        "/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?"
+        "/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]*))?$"
     )
 
     def __init__(self, hs: "HomeServer"):
@@ -195,7 +195,7 @@ class PurgeHistoryRestServlet(RestServlet):
 
 
 class PurgeHistoryStatusRestServlet(RestServlet):
-    PATTERNS = admin_patterns("/purge_history_status/(?P<purge_id>[^/]+)")
+    PATTERNS = admin_patterns("/purge_history_status/(?P<purge_id>[^/]*)$")
 
     def __init__(self, hs: "HomeServer"):
         self.pagination_handler = hs.get_pagination_handler()
diff --git a/synapse/rest/admin/background_updates.py b/synapse/rest/admin/background_updates.py
index 479672d4d5..6ec00ce0b9 100644
--- a/synapse/rest/admin/background_updates.py
+++ b/synapse/rest/admin/background_updates.py
@@ -22,7 +22,7 @@ from synapse.http.servlet import (
     parse_json_object_from_request,
 )
 from synapse.http.site import SynapseRequest
-from synapse.rest.admin._base import admin_patterns, assert_user_is_admin
+from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
 from synapse.types import JsonDict
 
 if TYPE_CHECKING:
@@ -41,8 +41,7 @@ class BackgroundUpdateEnabledRestServlet(RestServlet):
         self._data_stores = hs.get_datastores()
 
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        requester = await self._auth.get_user_by_req(request)
-        await assert_user_is_admin(self._auth, requester.user)
+        await assert_requester_is_admin(self._auth, request)
 
         # We need to check that all configured databases have updates enabled.
         # (They *should* all be in sync.)
@@ -51,8 +50,7 @@ class BackgroundUpdateEnabledRestServlet(RestServlet):
         return HTTPStatus.OK, {"enabled": enabled}
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        requester = await self._auth.get_user_by_req(request)
-        await assert_user_is_admin(self._auth, requester.user)
+        await assert_requester_is_admin(self._auth, request)
 
         body = parse_json_object_from_request(request)
 
@@ -84,8 +82,7 @@ class BackgroundUpdateRestServlet(RestServlet):
         self._data_stores = hs.get_datastores()
 
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        requester = await self._auth.get_user_by_req(request)
-        await assert_user_is_admin(self._auth, requester.user)
+        await assert_requester_is_admin(self._auth, request)
 
         # We need to check that all configured databases have updates enabled.
         # (They *should* all be in sync.)
@@ -111,15 +108,14 @@ class BackgroundUpdateRestServlet(RestServlet):
 class BackgroundUpdateStartJobRestServlet(RestServlet):
     """Allows to start specific background updates"""
 
-    PATTERNS = admin_patterns("/background_updates/start_job")
+    PATTERNS = admin_patterns("/background_updates/start_job$")
 
     def __init__(self, hs: "HomeServer"):
         self._auth = hs.get_auth()
         self._store = hs.get_datastore()
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        requester = await self._auth.get_user_by_req(request)
-        await assert_user_is_admin(self._auth, requester.user)
+        await assert_requester_is_admin(self._auth, request)
 
         body = parse_json_object_from_request(request)
         assert_params_in_dict(body, ["job_name"])
diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py
index 2e5a6600d3..d9905ff560 100644
--- a/synapse/rest/admin/devices.py
+++ b/synapse/rest/admin/devices.py
@@ -42,10 +42,10 @@ class DeviceRestServlet(RestServlet):
 
     def __init__(self, hs: "HomeServer"):
         super().__init__()
-        self.hs = hs
         self.auth = hs.get_auth()
         self.device_handler = hs.get_device_handler()
         self.store = hs.get_datastore()
+        self.is_mine = hs.is_mine
 
     async def on_GET(
         self, request: SynapseRequest, user_id: str, device_id: str
@@ -53,7 +53,7 @@ class DeviceRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         target_user = UserID.from_string(user_id)
-        if not self.hs.is_mine(target_user):
+        if not self.is_mine(target_user):
             raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
 
         u = await self.store.get_user_by_id(target_user.to_string())
@@ -63,6 +63,8 @@ class DeviceRestServlet(RestServlet):
         device = await self.device_handler.get_device(
             target_user.to_string(), device_id
         )
+        if device is None:
+            raise NotFoundError("No device found")
         return HTTPStatus.OK, device
 
     async def on_DELETE(
@@ -71,7 +73,7 @@ class DeviceRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         target_user = UserID.from_string(user_id)
-        if not self.hs.is_mine(target_user):
+        if not self.is_mine(target_user):
             raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
 
         u = await self.store.get_user_by_id(target_user.to_string())
@@ -87,7 +89,7 @@ class DeviceRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         target_user = UserID.from_string(user_id)
-        if not self.hs.is_mine(target_user):
+        if not self.is_mine(target_user):
             raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
 
         u = await self.store.get_user_by_id(target_user.to_string())
@@ -109,14 +111,10 @@ class DevicesRestServlet(RestServlet):
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2")
 
     def __init__(self, hs: "HomeServer"):
-        """
-        Args:
-            hs: server
-        """
-        self.hs = hs
         self.auth = hs.get_auth()
         self.device_handler = hs.get_device_handler()
         self.store = hs.get_datastore()
+        self.is_mine = hs.is_mine
 
     async def on_GET(
         self, request: SynapseRequest, user_id: str
@@ -124,7 +122,7 @@ class DevicesRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         target_user = UserID.from_string(user_id)
-        if not self.hs.is_mine(target_user):
+        if not self.is_mine(target_user):
             raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
 
         u = await self.store.get_user_by_id(target_user.to_string())
@@ -144,10 +142,10 @@ class DeleteDevicesRestServlet(RestServlet):
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/delete_devices$", "v2")
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.auth = hs.get_auth()
         self.device_handler = hs.get_device_handler()
         self.store = hs.get_datastore()
+        self.is_mine = hs.is_mine
 
     async def on_POST(
         self, request: SynapseRequest, user_id: str
@@ -155,7 +153,7 @@ class DeleteDevicesRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         target_user = UserID.from_string(user_id)
-        if not self.hs.is_mine(target_user):
+        if not self.is_mine(target_user):
             raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
 
         u = await self.store.get_user_by_id(target_user.to_string())
diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py
index 5ee8b11110..38477f8ead 100644
--- a/synapse/rest/admin/event_reports.py
+++ b/synapse/rest/admin/event_reports.py
@@ -52,7 +52,6 @@ class EventReportsRestServlet(RestServlet):
     PATTERNS = admin_patterns("/event_reports$")
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
 
@@ -115,7 +114,6 @@ class EventReportDetailRestServlet(RestServlet):
     PATTERNS = admin_patterns("/event_reports/(?P<report_id>[^/]*)$")
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
 
diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py
index 744687be35..50d88c9109 100644
--- a/synapse/rest/admin/federation.py
+++ b/synapse/rest/admin/federation.py
@@ -100,7 +100,7 @@ class DestinationsRestServlet(RestServlet):
         200 OK with details of a destination if success otherwise an error.
     """
 
-    PATTERNS = admin_patterns("/federation/destinations/(?P<destination>[^/]+)$")
+    PATTERNS = admin_patterns("/federation/destinations/(?P<destination>[^/]*)$")
 
     def __init__(self, hs: "HomeServer"):
         self._auth = hs.get_auth()
diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py
index a27110388f..cd697e180e 100644
--- a/synapse/rest/admin/groups.py
+++ b/synapse/rest/admin/groups.py
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
 class DeleteGroupAdminRestServlet(RestServlet):
     """Allows deleting of local groups"""
 
-    PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)")
+    PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)$")
 
     def __init__(self, hs: "HomeServer"):
         self.group_server = hs.get_groups_server_handler()
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 9e23e2d8fc..7236e4027f 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -17,7 +17,7 @@ import logging
 from http import HTTPStatus
 from typing import TYPE_CHECKING, Tuple
 
-from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
+from synapse.api.errors import Codes, NotFoundError, SynapseError
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
@@ -41,9 +41,9 @@ class QuarantineMediaInRoom(RestServlet):
     """
 
     PATTERNS = [
-        *admin_patterns("/room/(?P<room_id>[^/]+)/media/quarantine$"),
+        *admin_patterns("/room/(?P<room_id>[^/]*)/media/quarantine$"),
         # This path kept around for legacy reasons
-        *admin_patterns("/quarantine_media/(?P<room_id>[^/]+)"),
+        *admin_patterns("/quarantine_media/(?P<room_id>[^/]*)$"),
     ]
 
     def __init__(self, hs: "HomeServer"):
@@ -71,7 +71,7 @@ class QuarantineMediaByUser(RestServlet):
     this server.
     """
 
-    PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine$")
+    PATTERNS = admin_patterns("/user/(?P<user_id>[^/]*)/media/quarantine$")
 
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
@@ -99,7 +99,7 @@ class QuarantineMediaByID(RestServlet):
     """
 
     PATTERNS = admin_patterns(
-        "/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
+        "/media/quarantine/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$"
     )
 
     def __init__(self, hs: "HomeServer"):
@@ -128,7 +128,7 @@ class UnquarantineMediaByID(RestServlet):
     """
 
     PATTERNS = admin_patterns(
-        "/media/unquarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
+        "/media/unquarantine/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$"
     )
 
     def __init__(self, hs: "HomeServer"):
@@ -138,8 +138,7 @@ class UnquarantineMediaByID(RestServlet):
     async def on_POST(
         self, request: SynapseRequest, server_name: str, media_id: str
     ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_requester_is_admin(self.auth, request)
 
         logging.info(
             "Remove from quarantine local media by ID: %s/%s", server_name, media_id
@@ -154,7 +153,7 @@ class UnquarantineMediaByID(RestServlet):
 class ProtectMediaByID(RestServlet):
     """Protect local media from being quarantined."""
 
-    PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)")
+    PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]*)$")
 
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
@@ -163,8 +162,7 @@ class ProtectMediaByID(RestServlet):
     async def on_POST(
         self, request: SynapseRequest, media_id: str
     ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_requester_is_admin(self.auth, request)
 
         logging.info("Protecting local media by ID: %s", media_id)
 
@@ -177,7 +175,7 @@ class ProtectMediaByID(RestServlet):
 class UnprotectMediaByID(RestServlet):
     """Unprotect local media from being quarantined."""
 
-    PATTERNS = admin_patterns("/media/unprotect/(?P<media_id>[^/]+)")
+    PATTERNS = admin_patterns("/media/unprotect/(?P<media_id>[^/]*)$")
 
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
@@ -186,8 +184,7 @@ class UnprotectMediaByID(RestServlet):
     async def on_POST(
         self, request: SynapseRequest, media_id: str
     ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_requester_is_admin(self.auth, request)
 
         logging.info("Unprotecting local media by ID: %s", media_id)
 
@@ -200,7 +197,7 @@ class UnprotectMediaByID(RestServlet):
 class ListMediaInRoom(RestServlet):
     """Lists all of the media in a given room."""
 
-    PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media$")
+    PATTERNS = admin_patterns("/room/(?P<room_id>[^/]*)/media$")
 
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
@@ -209,10 +206,7 @@ class ListMediaInRoom(RestServlet):
     async def on_GET(
         self, request: SynapseRequest, room_id: str
     ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        is_admin = await self.auth.is_server_admin(requester.user)
-        if not is_admin:
-            raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")
+        await assert_requester_is_admin(self.auth, request)
 
         local_mxcs, remote_mxcs = await self.store.get_media_mxcs_in_room(room_id)
 
@@ -254,7 +248,7 @@ class PurgeMediaCacheRestServlet(RestServlet):
 class DeleteMediaByID(RestServlet):
     """Delete local media by a given ID. Removes it from this server."""
 
-    PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)")
+    PATTERNS = admin_patterns("/media/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$")
 
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
@@ -286,7 +280,7 @@ class DeleteMediaByDateSize(RestServlet):
     timestamp and size.
     """
 
-    PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/delete$")
+    PATTERNS = admin_patterns("/media/(?P<server_name>[^/]*)/delete$")
 
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
@@ -353,7 +347,7 @@ class UserMediaRestServlet(RestServlet):
         media that exist given for this user
     """
 
-    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$")
+    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/media$")
 
     def __init__(self, hs: "HomeServer"):
         self.is_mine = hs.is_mine
@@ -403,16 +397,7 @@ class UserMediaRestServlet(RestServlet):
                 request,
                 "order_by",
                 default=MediaSortOrder.CREATED_TS.value,
-                allowed_values=(
-                    MediaSortOrder.MEDIA_ID.value,
-                    MediaSortOrder.UPLOAD_NAME.value,
-                    MediaSortOrder.CREATED_TS.value,
-                    MediaSortOrder.LAST_ACCESS_TS.value,
-                    MediaSortOrder.MEDIA_LENGTH.value,
-                    MediaSortOrder.MEDIA_TYPE.value,
-                    MediaSortOrder.QUARANTINED_BY.value,
-                    MediaSortOrder.SAFE_FROM_QUARANTINE.value,
-                ),
+                allowed_values=[sort_order.value for sort_order in MediaSortOrder],
             )
             direction = parse_string(
                 request, "dir", default="f", allowed_values=("f", "b")
@@ -470,16 +455,7 @@ class UserMediaRestServlet(RestServlet):
                 request,
                 "order_by",
                 default=MediaSortOrder.CREATED_TS.value,
-                allowed_values=(
-                    MediaSortOrder.MEDIA_ID.value,
-                    MediaSortOrder.UPLOAD_NAME.value,
-                    MediaSortOrder.CREATED_TS.value,
-                    MediaSortOrder.LAST_ACCESS_TS.value,
-                    MediaSortOrder.MEDIA_LENGTH.value,
-                    MediaSortOrder.MEDIA_TYPE.value,
-                    MediaSortOrder.QUARANTINED_BY.value,
-                    MediaSortOrder.SAFE_FROM_QUARANTINE.value,
-                ),
+                allowed_values=[sort_order.value for sort_order in MediaSortOrder],
             )
             direction = parse_string(
                 request, "dir", default="f", allowed_values=("f", "b")
diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py
index 891b98c088..04948b6408 100644
--- a/synapse/rest/admin/registration_tokens.py
+++ b/synapse/rest/admin/registration_tokens.py
@@ -70,7 +70,6 @@ class ListRegistrationTokensRestServlet(RestServlet):
     PATTERNS = admin_patterns("/registration_tokens$")
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
 
@@ -109,7 +108,6 @@ class NewRegistrationTokenRestServlet(RestServlet):
     PATTERNS = admin_patterns("/registration_tokens/new$")
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
@@ -260,7 +258,6 @@ class RegistrationTokenRestServlet(RestServlet):
     PATTERNS = admin_patterns("/registration_tokens/(?P<token>[^/]*)$")
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.clock = hs.get_clock()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 829e86675a..17c6df1cc8 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -61,7 +61,7 @@ class RoomRestV2Servlet(RestServlet):
     If 'purge' is true, it will remove all traces of a room from the database.
     """
 
-    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$", "v2")
+    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)$", "v2")
 
     def __init__(self, hs: "HomeServer"):
         self._auth = hs.get_auth()
@@ -123,7 +123,7 @@ class RoomRestV2Servlet(RestServlet):
 class DeleteRoomStatusByRoomIdRestServlet(RestServlet):
     """Get the status of the delete room background task."""
 
-    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete_status$", "v2")
+    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/delete_status$", "v2")
 
     def __init__(self, hs: "HomeServer"):
         self._auth = hs.get_auth()
@@ -160,7 +160,7 @@ class DeleteRoomStatusByRoomIdRestServlet(RestServlet):
 class DeleteRoomStatusByDeleteIdRestServlet(RestServlet):
     """Get the status of the delete room background task."""
 
-    PATTERNS = admin_patterns("/rooms/delete_status/(?P<delete_id>[^/]+)$", "v2")
+    PATTERNS = admin_patterns("/rooms/delete_status/(?P<delete_id>[^/]*)$", "v2")
 
     def __init__(self, hs: "HomeServer"):
         self._auth = hs.get_auth()
@@ -193,35 +193,17 @@ class ListRoomRestServlet(RestServlet):
         self.admin_handler = hs.get_admin_handler()
 
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_requester_is_admin(self.auth, request)
 
         # Extract query parameters
         start = parse_integer(request, "from", default=0)
         limit = parse_integer(request, "limit", default=100)
-        order_by = parse_string(request, "order_by", default=RoomSortOrder.NAME.value)
-        if order_by not in (
-            RoomSortOrder.ALPHABETICAL.value,
-            RoomSortOrder.SIZE.value,
-            RoomSortOrder.NAME.value,
-            RoomSortOrder.CANONICAL_ALIAS.value,
-            RoomSortOrder.JOINED_MEMBERS.value,
-            RoomSortOrder.JOINED_LOCAL_MEMBERS.value,
-            RoomSortOrder.VERSION.value,
-            RoomSortOrder.CREATOR.value,
-            RoomSortOrder.ENCRYPTION.value,
-            RoomSortOrder.FEDERATABLE.value,
-            RoomSortOrder.PUBLIC.value,
-            RoomSortOrder.JOIN_RULES.value,
-            RoomSortOrder.GUEST_ACCESS.value,
-            RoomSortOrder.HISTORY_VISIBILITY.value,
-            RoomSortOrder.STATE_EVENTS.value,
-        ):
-            raise SynapseError(
-                HTTPStatus.BAD_REQUEST,
-                "Unknown value for order_by: %s" % (order_by,),
-                errcode=Codes.INVALID_PARAM,
-            )
+        order_by = parse_string(
+            request,
+            "order_by",
+            default=RoomSortOrder.NAME.value,
+            allowed_values=[sort_order.value for sort_order in RoomSortOrder],
+        )
 
         search_term = parse_string(request, "search_term", encoding="utf-8")
         if search_term == "":
@@ -292,10 +274,9 @@ class RoomRestServlet(RestServlet):
     TODO: Add on_POST to allow room creation without joining the room
     """
 
-    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$")
+    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)$")
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
         self.room_shutdown_handler = hs.get_room_shutdown_handler()
@@ -397,10 +378,9 @@ class RoomMembersRestServlet(RestServlet):
     Get members list of a room.
     """
 
-    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/members")
+    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/members$")
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
 
@@ -424,10 +404,9 @@ class RoomStateRestServlet(RestServlet):
     Get full state within a room.
     """
 
-    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/state")
+    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/state$")
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
@@ -436,8 +415,7 @@ class RoomStateRestServlet(RestServlet):
     async def on_GET(
         self, request: SynapseRequest, room_id: str
     ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_requester_is_admin(self.auth, request)
 
         ret = await self.store.get_room(room_id)
         if not ret:
@@ -454,14 +432,14 @@ class RoomStateRestServlet(RestServlet):
 
 class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
 
-    PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
+    PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)$")
 
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
-        self.hs = hs
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()
         self.state_handler = hs.get_state_handler()
+        self.is_mine = hs.is_mine
 
     async def on_POST(
         self, request: SynapseRequest, room_identifier: str
@@ -477,7 +455,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
         assert_params_in_dict(content, ["user_id"])
         target_user = UserID.from_string(content["user_id"])
 
-        if not self.hs.is_mine(target_user):
+        if not self.is_mine(target_user):
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST,
                 "This endpoint can only be used with local users",
@@ -542,11 +520,10 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
         }
     """
 
-    PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin")
+    PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin$")
 
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
-        self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
         self.event_creation_handler = hs.get_event_creation_handler()
@@ -688,19 +665,17 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
         GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
     """
 
-    PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities")
+    PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities$")
 
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
-        self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
 
     async def on_DELETE(
         self, request: SynapseRequest, room_identifier: str
     ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_requester_is_admin(self.auth, request)
 
         room_id, _ = await self.resolve_room_id(room_identifier)
 
@@ -710,8 +685,7 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
     async def on_GET(
         self, request: SynapseRequest, room_identifier: str
     ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_requester_is_admin(self.auth, request)
 
         room_id, _ = await self.resolve_room_id(room_identifier)
 
@@ -793,7 +767,7 @@ class BlockRoomRestServlet(RestServlet):
     On GET: Get blocking status of room and user who has blocked this room.
     """
 
-    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/block$")
+    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/block$")
 
     def __init__(self, hs: "HomeServer"):
         self._auth = hs.get_auth()
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index b295fb078b..15da9cd881 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -52,11 +52,11 @@ class SendServerNoticeServlet(RestServlet):
     """
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.auth = hs.get_auth()
         self.server_notices_manager = hs.get_server_notices_manager()
         self.admin_handler = hs.get_admin_handler()
         self.txns = HttpTransactionCache(hs)
+        self.is_mine = hs.is_mine
 
     def register(self, json_resource: HttpServer) -> None:
         PATTERN = "/send_server_notice"
@@ -88,7 +88,7 @@ class SendServerNoticeServlet(RestServlet):
             )
 
         target_user = UserID.from_string(body["user_id"])
-        if not self.hs.is_mine(target_user):
+        if not self.is_mine(target_user):
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST, "Server notices can only be sent to local users"
             )
diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py
index ca41fd45f2..7a6546372e 100644
--- a/synapse/rest/admin/statistics.py
+++ b/synapse/rest/admin/statistics.py
@@ -37,7 +37,6 @@ class UserMediaStatisticsRestServlet(RestServlet):
     PATTERNS = admin_patterns("/statistics/users/media$")
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
 
@@ -45,19 +44,16 @@ class UserMediaStatisticsRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         order_by = parse_string(
-            request, "order_by", default=UserSortOrder.USER_ID.value
+            request,
+            "order_by",
+            default=UserSortOrder.USER_ID.value,
+            allowed_values=(
+                UserSortOrder.MEDIA_LENGTH.value,
+                UserSortOrder.MEDIA_COUNT.value,
+                UserSortOrder.USER_ID.value,
+                UserSortOrder.DISPLAYNAME.value,
+            ),
         )
-        if order_by not in (
-            UserSortOrder.MEDIA_LENGTH.value,
-            UserSortOrder.MEDIA_COUNT.value,
-            UserSortOrder.USER_ID.value,
-            UserSortOrder.DISPLAYNAME.value,
-        ):
-            raise SynapseError(
-                HTTPStatus.BAD_REQUEST,
-                "Unknown value for order_by: %s" % (order_by,),
-                errcode=Codes.INVALID_PARAM,
-            )
 
         start = parse_integer(request, "from", default=0)
         if start < 0:
diff --git a/synapse/rest/admin/username_available.py b/synapse/rest/admin/username_available.py
index 2bf1472967..5353dc3682 100644
--- a/synapse/rest/admin/username_available.py
+++ b/synapse/rest/admin/username_available.py
@@ -37,7 +37,7 @@ class UsernameAvailableRestServlet(RestServlet):
         }
     """
 
-    PATTERNS = admin_patterns("/username_available")
+    PATTERNS = admin_patterns("/username_available$")
 
     def __init__(self, hs: "HomeServer"):
         self.auth = hs.get_auth()
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 2a60b602b1..db678da4cf 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -66,7 +66,6 @@ class UsersRestServletV2(RestServlet):
     """
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()
@@ -126,7 +125,7 @@ class UsersRestServletV2(RestServlet):
 
 
 class UserRestServletV2(RestServlet):
-    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)$", "v2")
+    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$", "v2")
 
     """Get request to list user details.
     This needs user to have administrator access in Synapse.
@@ -414,7 +413,7 @@ class UserRegisterServlet(RestServlet):
              nonce to the time it was generated, in int seconds.
     """
 
-    PATTERNS = admin_patterns("/register")
+    PATTERNS = admin_patterns("/register$")
     NONCE_TIMEOUT = 60
 
     def __init__(self, hs: "HomeServer"):
@@ -561,9 +560,9 @@ class WhoisRestServlet(RestServlet):
     ]
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()
+        self.is_mine = hs.is_mine
 
     async def on_GET(
         self, request: SynapseRequest, user_id: str
@@ -575,7 +574,7 @@ class WhoisRestServlet(RestServlet):
         if target_user != auth_user:
             await assert_user_is_admin(self.auth, auth_user)
 
-        if not self.hs.is_mine(target_user):
+        if not self.is_mine(target_user):
             raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only whois a local user")
 
         ret = await self.admin_handler.get_whois(target_user)
@@ -584,7 +583,7 @@ class WhoisRestServlet(RestServlet):
 
 
 class DeactivateAccountRestServlet(RestServlet):
-    PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)")
+    PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)$")
 
     def __init__(self, hs: "HomeServer"):
         self._deactivate_account_handler = hs.get_deactivate_account_handler()
@@ -630,7 +629,6 @@ class AccountValidityRenewServlet(RestServlet):
     PATTERNS = admin_patterns("/account_validity/validity$")
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.account_activity_handler = hs.get_account_validity_handler()
         self.auth = hs.get_auth()
 
@@ -674,11 +672,10 @@ class ResetPasswordRestServlet(RestServlet):
             200 OK with empty object if success otherwise an error.
     """
 
-    PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)")
+    PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)$")
 
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
-        self.hs = hs
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
         self._set_password_handler = hs.get_set_password_handler()
@@ -718,12 +715,12 @@ class SearchUsersRestServlet(RestServlet):
             200 OK with json object {list[dict[str, Any]], count} or empty object.
     """
 
-    PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)")
+    PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)$")
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
+        self.is_mine = hs.is_mine
 
     async def on_GET(
         self, request: SynapseRequest, target_user_id: str
@@ -740,7 +737,7 @@ class SearchUsersRestServlet(RestServlet):
         # if not is_admin and target_user != auth_user:
         #     raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")
 
-        if not self.hs.is_mine(target_user):
+        if not self.is_mine(target_user):
             raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only users a local user")
 
         term = parse_string(request, "term", required=True)
@@ -779,9 +776,9 @@ class UserAdminServlet(RestServlet):
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$")
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
+        self.is_mine = hs.is_mine
 
     async def on_GET(
         self, request: SynapseRequest, user_id: str
@@ -790,7 +787,7 @@ class UserAdminServlet(RestServlet):
 
         target_user = UserID.from_string(user_id)
 
-        if not self.hs.is_mine(target_user):
+        if not self.is_mine(target_user):
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST,
                 "Only local users can be admins of this homeserver",
@@ -813,7 +810,7 @@ class UserAdminServlet(RestServlet):
 
         assert_params_in_dict(body, ["admin"])
 
-        if not self.hs.is_mine(target_user):
+        if not self.is_mine(target_user):
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST,
                 "Only local users can be admins of this homeserver",
@@ -834,7 +831,7 @@ class UserMembershipRestServlet(RestServlet):
     Get room list of an user.
     """
 
-    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/joined_rooms$")
+    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/joined_rooms$")
 
     def __init__(self, hs: "HomeServer"):
         self.is_mine = hs.is_mine
@@ -909,10 +906,10 @@ class UserTokenRestServlet(RestServlet):
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/login$")
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
+        self.is_mine_id = hs.is_mine_id
 
     async def on_POST(
         self, request: SynapseRequest, user_id: str
@@ -921,7 +918,7 @@ class UserTokenRestServlet(RestServlet):
         await assert_user_is_admin(self.auth, requester.user)
         auth_user = requester.user
 
-        if not self.hs.is_mine_id(user_id):
+        if not self.is_mine_id(user_id):
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST, "Only local users can be logged in as"
             )
@@ -975,19 +972,19 @@ class ShadowBanRestServlet(RestServlet):
         {}
     """
 
-    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban")
+    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban$")
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
+        self.is_mine_id = hs.is_mine_id
 
     async def on_POST(
         self, request: SynapseRequest, user_id: str
     ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
-        if not self.hs.is_mine_id(user_id):
+        if not self.is_mine_id(user_id):
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned"
             )
@@ -1001,7 +998,7 @@ class ShadowBanRestServlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
-        if not self.hs.is_mine_id(user_id):
+        if not self.is_mine_id(user_id):
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned"
             )
@@ -1027,19 +1024,19 @@ class RateLimitRestServlet(RestServlet):
         }
     """
 
-    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/override_ratelimit")
+    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/override_ratelimit$")
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
+        self.is_mine_id = hs.is_mine_id
 
     async def on_GET(
         self, request: SynapseRequest, user_id: str
     ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
-        if not self.hs.is_mine_id(user_id):
+        if not self.is_mine_id(user_id):
             raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
 
         if not await self.store.get_user_by_id(user_id):
@@ -1068,7 +1065,7 @@ class RateLimitRestServlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
-        if not self.hs.is_mine_id(user_id):
+        if not self.is_mine_id(user_id):
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited"
             )
@@ -1113,7 +1110,7 @@ class RateLimitRestServlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
-        if not self.hs.is_mine_id(user_id):
+        if not self.is_mine_id(user_id):
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited"
             )
diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index 8566dc5cb5..ad6fd6492b 100644
--- a/synapse/rest/client/devices.py
+++ b/synapse/rest/client/devices.py
@@ -17,6 +17,7 @@ import logging
 from typing import TYPE_CHECKING, Tuple
 
 from synapse.api import errors
+from synapse.api.errors import NotFoundError
 from synapse.http.server import HttpServer
 from synapse.http.servlet import (
     RestServlet,
@@ -24,10 +25,9 @@ from synapse.http.servlet import (
     parse_json_object_from_request,
 )
 from synapse.http.site import SynapseRequest
+from synapse.rest.client._base import client_patterns, interactive_auth_handler
 from synapse.types import JsonDict
 
-from ._base import client_patterns, interactive_auth_handler
-
 if TYPE_CHECKING:
     from synapse.server import HomeServer
 
@@ -116,6 +116,8 @@ class DeviceRestServlet(RestServlet):
         device = await self.device_handler.get_device(
             requester.user.to_string(), device_id
         )
+        if device is None:
+            raise NotFoundError("No device found")
         return 200, device
 
     @interactive_auth_handler
diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py
index d1d8a984c6..b12a332776 100644
--- a/synapse/rest/client/notifications.py
+++ b/synapse/rest/client/notifications.py
@@ -15,6 +15,7 @@
 import logging
 from typing import TYPE_CHECKING, Tuple
 
+from synapse.api.constants import ReceiptTypes
 from synapse.events.utils import format_event_for_client_v2_without_room_id
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_integer, parse_string
@@ -54,7 +55,7 @@ class NotificationsServlet(RestServlet):
         )
 
         receipts_by_room = await self.store.get_receipts_for_user_with_orderings(
-            user_id, "m.read"
+            user_id, ReceiptTypes.READ
         )
 
         notif_event_ids = [pa["event_id"] for pa in push_actions]
diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py
index 43c04fac6f..f51be511d1 100644
--- a/synapse/rest/client/read_marker.py
+++ b/synapse/rest/client/read_marker.py
@@ -15,7 +15,7 @@
 import logging
 from typing import TYPE_CHECKING, Tuple
 
-from synapse.api.constants import ReadReceiptEventFields
+from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -48,7 +48,7 @@ class ReadMarkerRestServlet(RestServlet):
         await self.presence_handler.bump_presence_active_time(requester.user)
 
         body = parse_json_object_from_request(request)
-        read_event_id = body.get("m.read", None)
+        read_event_id = body.get(ReceiptTypes.READ, None)
         hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False)
 
         if not isinstance(hidden, bool):
@@ -62,7 +62,7 @@ class ReadMarkerRestServlet(RestServlet):
         if read_event_id:
             await self.receipts_handler.received_client_receipt(
                 room_id,
-                "m.read",
+                ReceiptTypes.READ,
                 user_id=requester.user.to_string(),
                 event_id=read_event_id,
                 hidden=hidden,
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index 2b25b9aad6..b24ad2d1be 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -16,7 +16,7 @@ import logging
 import re
 from typing import TYPE_CHECKING, Tuple
 
-from synapse.api.constants import ReadReceiptEventFields
+from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
 from synapse.api.errors import Codes, SynapseError
 from synapse.http import get_request_user_agent
 from synapse.http.server import HttpServer
@@ -53,7 +53,7 @@ class ReceiptRestServlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
 
-        if receipt_type != "m.read":
+        if receipt_type != ReceiptTypes.READ:
             raise SynapseError(400, "Receipt type must be 'm.read'")
 
         # Do not allow older SchildiChat and Element Android clients (prior to Element/1.[012].x) to send an empty body.
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index fc4e6921c5..ffa37ef06c 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -212,6 +212,7 @@ class RelationPaginationServlet(RestServlet):
 
             pagination_chunk = await self.store.get_relations_for_event(
                 event_id=parent_id,
+                room_id=room_id,
                 relation_type=relation_type,
                 event_type=event_type,
                 limit=limit,
@@ -317,6 +318,7 @@ class RelationAggregationPaginationServlet(RestServlet):
 
             pagination_chunk = await self.store.get_aggregation_groups_for_event(
                 event_id=parent_id,
+                room_id=room_id,
                 event_type=event_type,
                 limit=limit,
                 from_token=from_token,
@@ -383,7 +385,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
 
         # This checks that a) the event exists and b) the user is allowed to
         # view it.
-        await self.event_handler.get_event(requester.user, room_id, parent_id)
+        event = await self.event_handler.get_event(requester.user, room_id, parent_id)
+        if event is None:
+            raise SynapseError(404, "Unknown parent event.")
 
         if relation_type != RelationTypes.ANNOTATION:
             raise SynapseError(400, "Relation type must be 'annotation'")
@@ -402,6 +406,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
 
         result = await self.store.get_relations_for_event(
             event_id=parent_id,
+            room_id=room_id,
             relation_type=relation_type,
             event_type=event_type,
             aggregation_key=key,
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index f48e2e6ca2..60719331b6 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -187,7 +187,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
         state_key: str,
         txn_id: Optional[str] = None,
     ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
         if txn_id:
             set_tag("txn_id", txn_id)
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 88e4f5e063..dd90ffa123 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -293,6 +293,9 @@ class SyncRestServlet(RestServlet):
         response[
             "org.matrix.msc2732.device_unused_fallback_key_types"
         ] = sync_result.device_unused_fallback_key_types
+        response[
+            "device_unused_fallback_key_types"
+        ] = sync_result.device_unused_fallback_key_types
 
         if joined:
             response["rooms"][Membership.JOIN] = joined
diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
index 12b3ae120c..b9bfbea21b 100644
--- a/synapse/rest/key/v2/local_key_resource.py
+++ b/synapse/rest/key/v2/local_key_resource.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
 
 from canonicaljson import encode_canonical_json
 from signedjson.sign import sign_json
@@ -99,7 +99,7 @@ class LocalKey(Resource):
             json_object = sign_json(json_object, self.config.server.server_name, key)
         return json_object
 
-    def render_GET(self, request: Request) -> int:
+    def render_GET(self, request: Request) -> Optional[int]:
         time_now = self.clock.time_msec()
         # Update the expiry time if less than half the interval remains.
         if time_now + self.config.key.key_refresh_interval / 2 > self.valid_until_ts:
diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py
index 2a59552c20..cce1527ed9 100644
--- a/synapse/rest/media/v1/oembed.py
+++ b/synapse/rest/media/v1/oembed.py
@@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, List, Optional
 
 import attr
 
+from synapse.rest.media.v1.preview_html import parse_html_description
 from synapse.types import JsonDict
 from synapse.util import json_decoder
 
@@ -245,8 +246,6 @@ def calc_description_and_urls(open_graph_response: JsonDict, html_body: str) ->
     if video_urls:
         open_graph_response["og:video"] = video_urls[0]
 
-    from synapse.rest.media.v1.preview_url_resource import _calc_description
-
-    description = _calc_description(tree)
+    description = parse_html_description(tree)
     if description:
         open_graph_response["og:description"] = description
diff --git a/synapse/rest/media/v1/preview_html.py b/synapse/rest/media/v1/preview_html.py
new file mode 100644
index 0000000000..30b067dd42
--- /dev/null
+++ b/synapse/rest/media/v1/preview_html.py
@@ -0,0 +1,397 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import codecs
+import itertools
+import logging
+import re
+from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Union
+from urllib import parse as urlparse
+
+if TYPE_CHECKING:
+    from lxml import etree
+
+logger = logging.getLogger(__name__)
+
+_charset_match = re.compile(
+    br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9_-]+)"?', flags=re.I
+)
+_xml_encoding_match = re.compile(
+    br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9_-]+)"', flags=re.I
+)
+_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
+
+
+def _normalise_encoding(encoding: str) -> Optional[str]:
+    """Use the Python codec's name as the normalised entry."""
+    try:
+        return codecs.lookup(encoding).name
+    except LookupError:
+        return None
+
+
+def _get_html_media_encodings(
+    body: bytes, content_type: Optional[str]
+) -> Iterable[str]:
+    """
+    Get potential encoding of the body based on the (presumably) HTML body or the content-type header.
+
+    The precedence used for finding a character encoding is:
+
+    1. <meta> tag with a charset declared.
+    2. The XML document's character encoding attribute.
+    3. The Content-Type header.
+    4. Fallback to utf-8.
+    5. Fallback to windows-1252.
+
+    This roughly follows the algorithm used by BeautifulSoup's bs4.dammit.EncodingDetector.
+
+    Args:
+        body: The HTML document, as bytes.
+        content_type: The Content-Type header.
+
+    Returns:
+        The character encoding of the body, as a string.
+    """
+    # There's no point in returning an encoding more than once.
+    attempted_encodings: Set[str] = set()
+
+    # Limit searches to the first 1kb, since it ought to be at the top.
+    body_start = body[:1024]
+
+    # Check if it has an encoding set in a meta tag.
+    match = _charset_match.search(body_start)
+    if match:
+        encoding = _normalise_encoding(match.group(1).decode("ascii"))
+        if encoding:
+            attempted_encodings.add(encoding)
+            yield encoding
+
+    # TODO Support <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
+
+    # Check if it has an XML document with an encoding.
+    match = _xml_encoding_match.match(body_start)
+    if match:
+        encoding = _normalise_encoding(match.group(1).decode("ascii"))
+        if encoding and encoding not in attempted_encodings:
+            attempted_encodings.add(encoding)
+            yield encoding
+
+    # Check the HTTP Content-Type header for a character set.
+    if content_type:
+        content_match = _content_type_match.match(content_type)
+        if content_match:
+            encoding = _normalise_encoding(content_match.group(1))
+            if encoding and encoding not in attempted_encodings:
+                attempted_encodings.add(encoding)
+                yield encoding
+
+    # Finally, fallback to UTF-8, then windows-1252.
+    for fallback in ("utf-8", "cp1252"):
+        if fallback not in attempted_encodings:
+            yield fallback
+
+
+def decode_body(
+    body: bytes, uri: str, content_type: Optional[str] = None
+) -> Optional["etree.Element"]:
+    """
+    This uses lxml to parse the HTML document.
+
+    Args:
+        body: The HTML document, as bytes.
+        uri: The URI used to download the body.
+        content_type: The Content-Type header.
+
+    Returns:
+        The parsed HTML body, or None if an error occurred during processed.
+    """
+    # If there's no body, nothing useful is going to be found.
+    if not body:
+        return None
+
+    # The idea here is that multiple encodings are tried until one works.
+    # Unfortunately the result is never used and then LXML will decode the string
+    # again with the found encoding.
+    for encoding in _get_html_media_encodings(body, content_type):
+        try:
+            body.decode(encoding)
+        except Exception:
+            pass
+        else:
+            break
+    else:
+        logger.warning("Unable to decode HTML body for %s", uri)
+        return None
+
+    from lxml import etree
+
+    # Create an HTML parser.
+    parser = etree.HTMLParser(recover=True, encoding=encoding)
+
+    # Attempt to parse the body. Returns None if the body was successfully
+    # parsed, but no tree was found.
+    return etree.fromstring(body, parser)
+
+
+def parse_html_to_open_graph(
+    tree: "etree.Element", media_uri: str
+) -> Dict[str, Optional[str]]:
+    """
+    Parse the HTML document into an Open Graph response.
+
+    This uses lxml to search the HTML document for Open Graph data (or
+    synthesizes it from the document).
+
+    Args:
+        tree: The parsed HTML document.
+        media_url: The URI used to download the body.
+
+    Returns:
+        The Open Graph response as a dictionary.
+    """
+
+    # if we see any image URLs in the OG response, then spider them
+    # (although the client could choose to do this by asking for previews of those
+    # URLs to avoid DoSing the server)
+
+    # "og:type"         : "video",
+    # "og:url"          : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
+    # "og:site_name"    : "YouTube",
+    # "og:video:type"   : "application/x-shockwave-flash",
+    # "og:description"  : "Fun stuff happening here",
+    # "og:title"        : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
+    # "og:image"        : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
+    # "og:video:url"    : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
+    # "og:video:width"  : "1280"
+    # "og:video:height" : "720",
+    # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
+
+    og: Dict[str, Optional[str]] = {}
+    for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
+        if "content" in tag.attrib:
+            # if we've got more than 50 tags, someone is taking the piss
+            if len(og) >= 50:
+                logger.warning("Skipping OG for page with too many 'og:' tags")
+                return {}
+            og[tag.attrib["property"]] = tag.attrib["content"]
+
+    # TODO: grab article: meta tags too, e.g.:
+
+    # "article:publisher" : "https://www.facebook.com/thethudonline" />
+    # "article:author" content="https://www.facebook.com/thethudonline" />
+    # "article:tag" content="baby" />
+    # "article:section" content="Breaking News" />
+    # "article:published_time" content="2016-03-31T19:58:24+00:00" />
+    # "article:modified_time" content="2016-04-01T18:31:53+00:00" />
+
+    if "og:title" not in og:
+        # do some basic spidering of the HTML
+        title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
+        if title and title[0].text is not None:
+            og["og:title"] = title[0].text.strip()
+        else:
+            og["og:title"] = None
+
+    if "og:image" not in og:
+        # TODO: extract a favicon failing all else
+        meta_image = tree.xpath(
+            "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
+        )
+        if meta_image:
+            og["og:image"] = rebase_url(meta_image[0], media_uri)
+        else:
+            # TODO: consider inlined CSS styles as well as width & height attribs
+            images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
+            images = sorted(
+                images,
+                key=lambda i: (
+                    -1 * float(i.attrib["width"]) * float(i.attrib["height"])
+                ),
+            )
+            if not images:
+                images = tree.xpath("//img[@src]")
+            if images:
+                og["og:image"] = images[0].attrib["src"]
+
+    if "og:description" not in og:
+        meta_description = tree.xpath(
+            "//*/meta"
+            "[translate(@name, 'DESCRIPTION', 'description')='description']"
+            "/@content"
+        )
+        if meta_description:
+            og["og:description"] = meta_description[0]
+        else:
+            og["og:description"] = parse_html_description(tree)
+    elif og["og:description"]:
+        # This must be a non-empty string at this point.
+        assert isinstance(og["og:description"], str)
+        og["og:description"] = summarize_paragraphs([og["og:description"]])
+
+    # TODO: delete the url downloads to stop diskfilling,
+    # as we only ever cared about its OG
+    return og
+
+
+def parse_html_description(tree: "etree.Element") -> Optional[str]:
+    """
+    Calculate a text description based on an HTML document.
+
+    Grabs any text nodes which are inside the <body/> tag, unless they are within
+    an HTML5 semantic markup tag (<header/>, <nav/>, <aside/>, <footer/>), or
+    if they are within a <script/> or <style/> tag.
+
+    This is a very very very coarse approximation to a plain text render of the page.
+
+    Args:
+        tree: The parsed HTML document.
+
+    Returns:
+        The plain text description, or None if one cannot be generated.
+    """
+    # We don't just use XPATH here as that is slow on some machines.
+
+    from lxml import etree
+
+    TAGS_TO_REMOVE = (
+        "header",
+        "nav",
+        "aside",
+        "footer",
+        "script",
+        "noscript",
+        "style",
+        etree.Comment,
+    )
+
+    # Split all the text nodes into paragraphs (by splitting on new
+    # lines)
+    text_nodes = (
+        re.sub(r"\s+", "\n", el).strip()
+        for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
+    )
+    return summarize_paragraphs(text_nodes)
+
+
+def _iterate_over_text(
+    tree: "etree.Element", *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
+) -> Generator[str, None, None]:
+    """Iterate over the tree returning text nodes in a depth first fashion,
+    skipping text nodes inside certain tags.
+    """
+    # This is basically a stack that we extend using itertools.chain.
+    # This will either consist of an element to iterate over *or* a string
+    # to be returned.
+    elements = iter([tree])
+    while True:
+        el = next(elements, None)
+        if el is None:
+            return
+
+        if isinstance(el, str):
+            yield el
+        elif el.tag not in tags_to_ignore:
+            # el.text is the text before the first child, so we can immediately
+            # return it if the text exists.
+            if el.text:
+                yield el.text
+
+            # We add to the stack all the elements children, interspersed with
+            # each child's tail text (if it exists). The tail text of a node
+            # is text that comes *after* the node, so we always include it even
+            # if we ignore the child node.
+            elements = itertools.chain(
+                itertools.chain.from_iterable(  # Basically a flatmap
+                    [child, child.tail] if child.tail else [child]
+                    for child in el.iterchildren()
+                ),
+                elements,
+            )
+
+
+def rebase_url(url: str, base: str) -> str:
+    base_parts = list(urlparse.urlparse(base))
+    url_parts = list(urlparse.urlparse(url))
+    if not url_parts[0]:  # fix up schema
+        url_parts[0] = base_parts[0] or "http"
+    if not url_parts[1]:  # fix up hostname
+        url_parts[1] = base_parts[1]
+        if not url_parts[2].startswith("/"):
+            url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2]
+    return urlparse.urlunparse(url_parts)
+
+
+def summarize_paragraphs(
+    text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
+) -> Optional[str]:
+    """
+    Try to get a summary respecting first paragraph and then word boundaries.
+
+    Args:
+        text_nodes: The paragraphs to summarize.
+        min_size: The minimum number of words to include.
+        max_size: The maximum number of words to include.
+
+    Returns:
+        A summary of the text nodes, or None if that was not possible.
+    """
+
+    # TODO: Respect sentences?
+
+    description = ""
+
+    # Keep adding paragraphs until we get to the MIN_SIZE.
+    for text_node in text_nodes:
+        if len(description) < min_size:
+            text_node = re.sub(r"[\t \r\n]+", " ", text_node)
+            description += text_node + "\n\n"
+        else:
+            break
+
+    description = description.strip()
+    description = re.sub(r"[\t ]+", " ", description)
+    description = re.sub(r"[\t \r\n]*[\r\n]+", "\n\n", description)
+
+    # If the concatenation of paragraphs to get above MIN_SIZE
+    # took us over MAX_SIZE, then we need to truncate mid paragraph
+    if len(description) > max_size:
+        new_desc = ""
+
+        # This splits the paragraph into words, but keeping the
+        # (preceding) whitespace intact so we can easily concat
+        # words back together.
+        for match in re.finditer(r"\s*\S+", description):
+            word = match.group()
+
+            # Keep adding words while the total length is less than
+            # MAX_SIZE.
+            if len(word) + len(new_desc) < max_size:
+                new_desc += word
+            else:
+                # At this point the next word *will* take us over
+                # MAX_SIZE, but we also want to ensure that its not
+                # a huge word. If it is add it anyway and we'll
+                # truncate later.
+                if len(new_desc) < min_size:
+                    new_desc += word
+                break
+
+        # Double check that we're not over the limit
+        if len(new_desc) > max_size:
+            new_desc = new_desc[:max_size]
+
+        # We always add an ellipsis because at the very least
+        # we chopped mid paragraph.
+        description = new_desc.strip() + "…"
+    return description if description else None
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 054f3c296d..a3829d943b 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -12,18 +12,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import codecs
 import datetime
 import errno
 import fnmatch
-import itertools
 import logging
 import os
 import re
 import shutil
 import sys
 import traceback
-from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Iterable, Optional, Tuple
 from urllib import parse as urlparse
 
 import attr
@@ -45,6 +43,11 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.rest.media.v1._base import get_filename_from_headers
 from synapse.rest.media.v1.media_storage import MediaStorage
 from synapse.rest.media.v1.oembed import OEmbedProvider
+from synapse.rest.media.v1.preview_html import (
+    decode_body,
+    parse_html_to_open_graph,
+    rebase_url,
+)
 from synapse.types import JsonDict, UserID
 from synapse.util import json_encoder
 from synapse.util.async_helpers import ObservableDeferred
@@ -54,21 +57,11 @@ from synapse.util.stringutils import random_string
 from ._base import FileInfo
 
 if TYPE_CHECKING:
-    from lxml import etree
-
     from synapse.rest.media.v1.media_repository import MediaRepository
     from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
 
-_charset_match = re.compile(
-    br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9_-]+)"?', flags=re.I
-)
-_xml_encoding_match = re.compile(
-    br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9_-]+)"', flags=re.I
-)
-_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
-
 OG_TAG_NAME_MAXLEN = 50
 OG_TAG_VALUE_MAXLEN = 1000
 
@@ -311,7 +304,7 @@ class PreviewUrlResource(DirectServeJsonResource):
                 # If there was no oEmbed URL (or oEmbed parsing failed), attempt
                 # to generate the Open Graph information from the HTML.
                 if not oembed_url or not og:
-                    og = _calc_og(tree, media_info.uri)
+                    og = parse_html_to_open_graph(tree, media_info.uri)
 
                 await self._precache_image_url(user, media_info, og)
             else:
@@ -468,7 +461,7 @@ class PreviewUrlResource(DirectServeJsonResource):
         # request itself and benefit from the same caching etc.  But for now we
         # just rely on the caching on the master request to speed things up.
         image_info = await self._download_url(
-            _rebase_url(og["og:image"], media_info.uri), user
+            rebase_url(og["og:image"], media_info.uri), user
         )
 
         if _is_media(image_info.media_type):
@@ -632,301 +625,6 @@ class PreviewUrlResource(DirectServeJsonResource):
             logger.debug("No media removed from url cache")
 
 
-def _normalise_encoding(encoding: str) -> Optional[str]:
-    """Use the Python codec's name as the normalised entry."""
-    try:
-        return codecs.lookup(encoding).name
-    except LookupError:
-        return None
-
-
-def get_html_media_encodings(body: bytes, content_type: Optional[str]) -> Iterable[str]:
-    """
-    Get potential encoding of the body based on the (presumably) HTML body or the content-type header.
-
-    The precedence used for finding a character encoding is:
-
-    1. <meta> tag with a charset declared.
-    2. The XML document's character encoding attribute.
-    3. The Content-Type header.
-    4. Fallback to utf-8.
-    5. Fallback to windows-1252.
-
-    This roughly follows the algorithm used by BeautifulSoup's bs4.dammit.EncodingDetector.
-
-    Args:
-        body: The HTML document, as bytes.
-        content_type: The Content-Type header.
-
-    Returns:
-        The character encoding of the body, as a string.
-    """
-    # There's no point in returning an encoding more than once.
-    attempted_encodings: Set[str] = set()
-
-    # Limit searches to the first 1kb, since it ought to be at the top.
-    body_start = body[:1024]
-
-    # Check if it has an encoding set in a meta tag.
-    match = _charset_match.search(body_start)
-    if match:
-        encoding = _normalise_encoding(match.group(1).decode("ascii"))
-        if encoding:
-            attempted_encodings.add(encoding)
-            yield encoding
-
-    # TODO Support <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
-
-    # Check if it has an XML document with an encoding.
-    match = _xml_encoding_match.match(body_start)
-    if match:
-        encoding = _normalise_encoding(match.group(1).decode("ascii"))
-        if encoding and encoding not in attempted_encodings:
-            attempted_encodings.add(encoding)
-            yield encoding
-
-    # Check the HTTP Content-Type header for a character set.
-    if content_type:
-        content_match = _content_type_match.match(content_type)
-        if content_match:
-            encoding = _normalise_encoding(content_match.group(1))
-            if encoding and encoding not in attempted_encodings:
-                attempted_encodings.add(encoding)
-                yield encoding
-
-    # Finally, fallback to UTF-8, then windows-1252.
-    for fallback in ("utf-8", "cp1252"):
-        if fallback not in attempted_encodings:
-            yield fallback
-
-
-def decode_body(
-    body: bytes, uri: str, content_type: Optional[str] = None
-) -> Optional["etree.Element"]:
-    """
-    This uses lxml to parse the HTML document.
-
-    Args:
-        body: The HTML document, as bytes.
-        uri: The URI used to download the body.
-        content_type: The Content-Type header.
-
-    Returns:
-        The parsed HTML body, or None if an error occurred during processed.
-    """
-    # If there's no body, nothing useful is going to be found.
-    if not body:
-        return None
-
-    # The idea here is that multiple encodings are tried until one works.
-    # Unfortunately the result is never used and then LXML will decode the string
-    # again with the found encoding.
-    for encoding in get_html_media_encodings(body, content_type):
-        try:
-            body.decode(encoding)
-        except Exception:
-            pass
-        else:
-            break
-    else:
-        logger.warning("Unable to decode HTML body for %s", uri)
-        return None
-
-    from lxml import etree
-
-    # Create an HTML parser.
-    parser = etree.HTMLParser(recover=True, encoding=encoding)
-
-    # Attempt to parse the body. Returns None if the body was successfully
-    # parsed, but no tree was found.
-    return etree.fromstring(body, parser)
-
-
-def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
-    """
-    Calculate metadata for an HTML document.
-
-    This uses lxml to search the HTML document for Open Graph data.
-
-    Args:
-        tree: The parsed HTML document.
-        media_url: The URI used to download the body.
-
-    Returns:
-        The Open Graph response as a dictionary.
-    """
-
-    # if we see any image URLs in the OG response, then spider them
-    # (although the client could choose to do this by asking for previews of those
-    # URLs to avoid DoSing the server)
-
-    # "og:type"         : "video",
-    # "og:url"          : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
-    # "og:site_name"    : "YouTube",
-    # "og:video:type"   : "application/x-shockwave-flash",
-    # "og:description"  : "Fun stuff happening here",
-    # "og:title"        : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
-    # "og:image"        : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
-    # "og:video:url"    : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
-    # "og:video:width"  : "1280"
-    # "og:video:height" : "720",
-    # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
-
-    og: Dict[str, Optional[str]] = {}
-    for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
-        if "content" in tag.attrib:
-            # if we've got more than 50 tags, someone is taking the piss
-            if len(og) >= 50:
-                logger.warning("Skipping OG for page with too many 'og:' tags")
-                return {}
-            og[tag.attrib["property"]] = tag.attrib["content"]
-
-    # TODO: grab article: meta tags too, e.g.:
-
-    # "article:publisher" : "https://www.facebook.com/thethudonline" />
-    # "article:author" content="https://www.facebook.com/thethudonline" />
-    # "article:tag" content="baby" />
-    # "article:section" content="Breaking News" />
-    # "article:published_time" content="2016-03-31T19:58:24+00:00" />
-    # "article:modified_time" content="2016-04-01T18:31:53+00:00" />
-
-    if "og:title" not in og:
-        # do some basic spidering of the HTML
-        title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
-        if title and title[0].text is not None:
-            og["og:title"] = title[0].text.strip()
-        else:
-            og["og:title"] = None
-
-    if "og:image" not in og:
-        # TODO: extract a favicon failing all else
-        meta_image = tree.xpath(
-            "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
-        )
-        if meta_image:
-            og["og:image"] = _rebase_url(meta_image[0], media_uri)
-        else:
-            # TODO: consider inlined CSS styles as well as width & height attribs
-            images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
-            images = sorted(
-                images,
-                key=lambda i: (
-                    -1 * float(i.attrib["width"]) * float(i.attrib["height"])
-                ),
-            )
-            if not images:
-                images = tree.xpath("//img[@src]")
-            if images:
-                og["og:image"] = images[0].attrib["src"]
-
-    if "og:description" not in og:
-        meta_description = tree.xpath(
-            "//*/meta"
-            "[translate(@name, 'DESCRIPTION', 'description')='description']"
-            "/@content"
-        )
-        if meta_description:
-            og["og:description"] = meta_description[0]
-        else:
-            og["og:description"] = _calc_description(tree)
-    elif og["og:description"]:
-        # This must be a non-empty string at this point.
-        assert isinstance(og["og:description"], str)
-        og["og:description"] = summarize_paragraphs([og["og:description"]])
-
-    # TODO: delete the url downloads to stop diskfilling,
-    # as we only ever cared about its OG
-    return og
-
-
-def _calc_description(tree: "etree.Element") -> Optional[str]:
-    """
-    Calculate a text description based on an HTML document.
-
-    Grabs any text nodes which are inside the <body/> tag, unless they are within
-    an HTML5 semantic markup tag (<header/>, <nav/>, <aside/>, <footer/>), or
-    if they are within a <script/> or <style/> tag.
-
-    This is a very very very coarse approximation to a plain text render of the page.
-
-    Args:
-        tree: The parsed HTML document.
-
-    Returns:
-        The plain text description, or None if one cannot be generated.
-    """
-    # We don't just use XPATH here as that is slow on some machines.
-
-    from lxml import etree
-
-    TAGS_TO_REMOVE = (
-        "header",
-        "nav",
-        "aside",
-        "footer",
-        "script",
-        "noscript",
-        "style",
-        etree.Comment,
-    )
-
-    # Split all the text nodes into paragraphs (by splitting on new
-    # lines)
-    text_nodes = (
-        re.sub(r"\s+", "\n", el).strip()
-        for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
-    )
-    return summarize_paragraphs(text_nodes)
-
-
-def _iterate_over_text(
-    tree: "etree.Element", *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
-) -> Generator[str, None, None]:
-    """Iterate over the tree returning text nodes in a depth first fashion,
-    skipping text nodes inside certain tags.
-    """
-    # This is basically a stack that we extend using itertools.chain.
-    # This will either consist of an element to iterate over *or* a string
-    # to be returned.
-    elements = iter([tree])
-    while True:
-        el = next(elements, None)
-        if el is None:
-            return
-
-        if isinstance(el, str):
-            yield el
-        elif el.tag not in tags_to_ignore:
-            # el.text is the text before the first child, so we can immediately
-            # return it if the text exists.
-            if el.text:
-                yield el.text
-
-            # We add to the stack all the elements children, interspersed with
-            # each child's tail text (if it exists). The tail text of a node
-            # is text that comes *after* the node, so we always include it even
-            # if we ignore the child node.
-            elements = itertools.chain(
-                itertools.chain.from_iterable(  # Basically a flatmap
-                    [child, child.tail] if child.tail else [child]
-                    for child in el.iterchildren()
-                ),
-                elements,
-            )
-
-
-def _rebase_url(url: str, base: str) -> str:
-    base_parts = list(urlparse.urlparse(base))
-    url_parts = list(urlparse.urlparse(url))
-    if not url_parts[0]:  # fix up schema
-        url_parts[0] = base_parts[0] or "http"
-    if not url_parts[1]:  # fix up hostname
-        url_parts[1] = base_parts[1]
-        if not url_parts[2].startswith("/"):
-            url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2]
-    return urlparse.urlunparse(url_parts)
-
-
 def _is_media(content_type: str) -> bool:
     return content_type.lower().startswith("image/")
 
@@ -940,68 +638,3 @@ def _is_html(content_type: str) -> bool:
 
 def _is_json(content_type: str) -> bool:
     return content_type.lower().startswith("application/json")
-
-
-def summarize_paragraphs(
-    text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
-) -> Optional[str]:
-    """
-    Try to get a summary respecting first paragraph and then word boundaries.
-
-    Args:
-        text_nodes: The paragraphs to summarize.
-        min_size: The minimum number of words to include.
-        max_size: The maximum number of words to include.
-
-    Returns:
-        A summary of the text nodes, or None if that was not possible.
-    """
-
-    # TODO: Respect sentences?
-
-    description = ""
-
-    # Keep adding paragraphs until we get to the MIN_SIZE.
-    for text_node in text_nodes:
-        if len(description) < min_size:
-            text_node = re.sub(r"[\t \r\n]+", " ", text_node)
-            description += text_node + "\n\n"
-        else:
-            break
-
-    description = description.strip()
-    description = re.sub(r"[\t ]+", " ", description)
-    description = re.sub(r"[\t \r\n]*[\r\n]+", "\n\n", description)
-
-    # If the concatenation of paragraphs to get above MIN_SIZE
-    # took us over MAX_SIZE, then we need to truncate mid paragraph
-    if len(description) > max_size:
-        new_desc = ""
-
-        # This splits the paragraph into words, but keeping the
-        # (preceding) whitespace intact so we can easily concat
-        # words back together.
-        for match in re.finditer(r"\s*\S+", description):
-            word = match.group()
-
-            # Keep adding words while the total length is less than
-            # MAX_SIZE.
-            if len(word) + len(new_desc) < max_size:
-                new_desc += word
-            else:
-                # At this point the next word *will* take us over
-                # MAX_SIZE, but we also want to ensure that its not
-                # a huge word. If it is add it anyway and we'll
-                # truncate later.
-                if len(new_desc) < min_size:
-                    new_desc += word
-                break
-
-        # Double check that we're not over the limit
-        if len(new_desc) > max_size:
-            new_desc = new_desc[:max_size]
-
-        # We always add an ellipsis because at the very least
-        # we chopped mid paragraph.
-        description = new_desc.strip() + "…"
-    return description if description else None
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 3056e64ff5..7967011afd 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -17,10 +17,8 @@ import logging
 from abc import ABCMeta
 from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union
 
-from synapse.storage.database import LoggingTransaction  # noqa: F401
-from synapse.storage.database import make_in_list_sql_clause  # noqa: F401
-from synapse.storage.database import DatabasePool
-from synapse.storage.types import Connection
+from synapse.storage.database import make_in_list_sql_clause  # noqa: F401; noqa: F401
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.types import get_domain_from_id
 from synapse.util import json_decoder
 
@@ -38,7 +36,12 @@ class SQLBaseStore(metaclass=ABCMeta):
     per data store (and not one per physical database).
     """
 
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         self.hs = hs
         self._clock = hs.get_clock()
         self.database_engine = database.engine
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 0693d39006..a219999f15 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -13,8 +13,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 import logging
 import time
+import types
 from collections import defaultdict
 from sys import intern
 from time import monotonic as monotonic_time
@@ -175,7 +177,7 @@ class LoggingDatabaseConnection:
     def rollback(self) -> None:
         self.conn.rollback()
 
-    def __enter__(self) -> "Connection":
+    def __enter__(self) -> "LoggingDatabaseConnection":
         self.conn.__enter__()
         return self
 
@@ -526,6 +528,12 @@ class DatabasePool:
         the function will correctly handle being aborted and retried half way
         through its execution.
 
+        Similarly, the arguments to `func` (`args`, `kwargs`) should not be generators,
+        since they could be evaluated multiple times (which would produce an empty
+        result on the second or subsequent evaluation). Likewise, the closure of `func`
+        must not reference any generators.  This method attempts to detect such usage
+        and will log an error.
+
         Args:
             conn
             desc
@@ -536,6 +544,39 @@ class DatabasePool:
             **kwargs
         """
 
+        # Robustness check: ensure that none of the arguments are generators, since that
+        # will fail if we have to repeat the transaction.
+        # For now, we just log an error, and hope that it works on the first attempt.
+        # TODO: raise an exception.
+        for i, arg in enumerate(args):
+            if inspect.isgenerator(arg):
+                logger.error(
+                    "Programming error: generator passed to new_transaction as "
+                    "argument %i to function %s",
+                    i,
+                    func,
+                )
+        for name, val in kwargs.items():
+            if inspect.isgenerator(val):
+                logger.error(
+                    "Programming error: generator passed to new_transaction as "
+                    "argument %s to function %s",
+                    name,
+                    func,
+                )
+        # also check variables referenced in func's closure
+        if inspect.isfunction(func):
+            f = cast(types.FunctionType, func)
+            if f.__closure__:
+                for i, cell in enumerate(f.__closure__):
+                    if inspect.isgenerator(cell.cell_contents):
+                        logger.error(
+                            "Programming error: function %s references generator %s "
+                            "via its closure",
+                            f,
+                            f.__code__.co_freevars[i],
+                        )
+
         start = monotonic_time()
         txn_id = self._TXN_ID
 
@@ -896,6 +937,9 @@ class DatabasePool:
     ) -> None:
         """Executes an INSERT query on the named table.
 
+        The input is given as a list of dicts, with one dict per row.
+        Generally simple_insert_many_values should be preferred for new code.
+
         Args:
             table: string giving the table name
             values: dict of new column names and values for them
@@ -909,6 +953,9 @@ class DatabasePool:
     ) -> None:
         """Executes an INSERT query on the named table.
 
+        The input is given as a list of dicts, with one dict per row.
+        Generally simple_insert_many_values_txn should be preferred for new code.
+
         Args:
             txn: The transaction to use.
             table: string giving the table name
@@ -933,23 +980,66 @@ class DatabasePool:
             if k != keys[0]:
                 raise RuntimeError("All items must have the same keys")
 
+        return DatabasePool.simple_insert_many_values_txn(txn, table, keys[0], vals)
+
+    async def simple_insert_many_values(
+        self,
+        table: str,
+        keys: Collection[str],
+        values: Iterable[Iterable[Any]],
+        desc: str,
+    ) -> None:
+        """Executes an INSERT query on the named table.
+
+        The input is given as a list of rows, where each row is a list of values.
+        (Actually any iterable is fine.)
+
+        Args:
+            table: string giving the table name
+            keys: list of column names
+            values: for each row, a list of values in the same order as `keys`
+            desc: description of the transaction, for logging and metrics
+        """
+        await self.runInteraction(
+            desc, self.simple_insert_many_values_txn, table, keys, values
+        )
+
+    @staticmethod
+    def simple_insert_many_values_txn(
+        txn: LoggingTransaction,
+        table: str,
+        keys: Collection[str],
+        values: Iterable[Iterable[Any]],
+    ) -> None:
+        """Executes an INSERT query on the named table.
+
+        The input is given as a list of rows, where each row is a list of values.
+        (Actually any iterable is fine.)
+
+        Args:
+            txn: The transaction to use.
+            table: string giving the table name
+            keys: list of column names
+            values: for each row, a list of values in the same order as `keys`
+        """
+
         if isinstance(txn.database_engine, PostgresEngine):
             # We use `execute_values` as it can be a lot faster than `execute_batch`,
             # but it's only available on postgres.
             sql = "INSERT INTO %s (%s) VALUES ?" % (
                 table,
-                ", ".join(k for k in keys[0]),
+                ", ".join(k for k in keys),
             )
 
-            txn.execute_values(sql, vals, fetch=False)
+            txn.execute_values(sql, values, fetch=False)
         else:
             sql = "INSERT INTO %s (%s) VALUES(%s)" % (
                 table,
-                ", ".join(k for k in keys[0]),
-                ", ".join("?" for _ in keys[0]),
+                ", ".join(k for k in keys),
+                ", ".join("?" for _ in keys),
             )
 
-            txn.execute_batch(sql, vals)
+            txn.execute_batch(sql, values)
 
     async def simple_upsert(
         self,
@@ -1177,9 +1267,9 @@ class DatabasePool:
         self,
         table: str,
         key_names: Collection[str],
-        key_values: Collection[Iterable[Any]],
+        key_values: Collection[Collection[Any]],
         value_names: Collection[str],
-        value_values: Iterable[Iterable[Any]],
+        value_values: Collection[Collection[Any]],
         desc: str,
     ) -> None:
         """
@@ -1871,7 +1961,7 @@ class DatabasePool:
         self,
         table: str,
         column: str,
-        iterable: Iterable[Any],
+        iterable: Collection[Any],
         keyvalues: Dict[str, Any],
         desc: str,
     ) -> int:
@@ -1882,7 +1972,8 @@ class DatabasePool:
         Args:
             table: string giving the table name
             column: column name to test for inclusion against `iterable`
-            iterable: list
+            iterable: list of values to match against `column`. NB cannot be a generator
+                as it may be evaluated multiple times.
             keyvalues: dict of column names and values to select the rows with
             desc: description of the transaction, for logging and metrics
 
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 9ff2d8d8c3..716b25dd34 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -18,7 +18,7 @@ import logging
 from typing import TYPE_CHECKING, List, Optional, Tuple
 
 from synapse.config.homeserver import HomeServerConfig
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.stats import UserSortOrder
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import (
@@ -129,7 +129,12 @@ class DataStore(
     LockStore,
     SessionStore,
 ):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         self.hs = hs
         self._clock = hs.get_clock()
         self.database_engine = database.engine
@@ -143,9 +148,6 @@ class DataStore(
                 ("device_lists_outbound_pokes", "stream_id"),
             ],
         )
-        self._cross_signing_id_gen = StreamIdGenerator(
-            db_conn, "e2e_cross_signing_keys", "stream_id"
-        )
 
         self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
         self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index f8bec266ac..32a553fdd7 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -14,15 +14,25 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, cast
 
 from synapse.api.constants import AccountDataTypes
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
-from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage._base import db_to_json
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import (
+    AbstractStreamIdGenerator,
+    AbstractStreamIdTracker,
+    MultiWriterIdGenerator,
+    StreamIdGenerator,
+)
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
@@ -34,13 +44,19 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class AccountDataWorkerStore(SQLBaseStore):
-    """This is an abstract base class where subclasses must implement
-    `get_max_account_data_stream_id` which can be called in the initializer.
-    """
+class AccountDataWorkerStore(CacheInvalidationWorkerStore):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
+        super().__init__(database, db_conn, hs)
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
-        self._instance_name = hs.get_instance_name()
+        # `_can_write_to_account_data` indicates whether the current worker is allowed
+        # to write account data. A value of `True` implies that `_account_data_id_gen`
+        # is an `AbstractStreamIdGenerator` and not just a tracker.
+        self._account_data_id_gen: AbstractStreamIdTracker
 
         if isinstance(database.engine, PostgresEngine):
             self._can_write_to_account_data = (
@@ -61,8 +77,6 @@ class AccountDataWorkerStore(SQLBaseStore):
                 writers=hs.config.worker.writers.account_data,
             )
         else:
-            self._can_write_to_account_data = True
-
             # We shouldn't be running in worker mode with SQLite, but its useful
             # to support it for unit tests.
             #
@@ -70,7 +84,8 @@ class AccountDataWorkerStore(SQLBaseStore):
             # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
             # updated over replication. (Multiple writers are not supported for
             # SQLite).
-            if hs.get_instance_name() in hs.config.worker.writers.account_data:
+            if self._instance_name in hs.config.worker.writers.account_data:
+                self._can_write_to_account_data = True
                 self._account_data_id_gen = StreamIdGenerator(
                     db_conn,
                     "room_account_data",
@@ -90,8 +105,6 @@ class AccountDataWorkerStore(SQLBaseStore):
             "AccountDataAndTagsChangeCache", account_max
         )
 
-        super().__init__(database, db_conn, hs)
-
     def get_max_account_data_stream_id(self) -> int:
         """Get the current max stream ID for account data stream
 
@@ -113,7 +126,9 @@ class AccountDataWorkerStore(SQLBaseStore):
             room_id string to per room account_data dicts.
         """
 
-        def get_account_data_for_user_txn(txn):
+        def get_account_data_for_user_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
             rows = self.db_pool.simple_select_list_txn(
                 txn,
                 "account_data",
@@ -132,7 +147,7 @@ class AccountDataWorkerStore(SQLBaseStore):
                 ["room_id", "account_data_type", "content"],
             )
 
-            by_room = {}
+            by_room: Dict[str, Dict[str, JsonDict]] = {}
             for row in rows:
                 room_data = by_room.setdefault(row["room_id"], {})
                 room_data[row["account_data_type"]] = db_to_json(row["content"])
@@ -177,7 +192,9 @@ class AccountDataWorkerStore(SQLBaseStore):
             A dict of the room account_data
         """
 
-        def get_account_data_for_room_txn(txn):
+        def get_account_data_for_room_txn(
+            txn: LoggingTransaction,
+        ) -> Dict[str, JsonDict]:
             rows = self.db_pool.simple_select_list_txn(
                 txn,
                 "room_account_data",
@@ -207,7 +224,9 @@ class AccountDataWorkerStore(SQLBaseStore):
             The room account_data for that type, or None if there isn't any set.
         """
 
-        def get_account_data_for_room_and_type_txn(txn):
+        def get_account_data_for_room_and_type_txn(
+            txn: LoggingTransaction,
+        ) -> Optional[JsonDict]:
             content_json = self.db_pool.simple_select_one_onecol_txn(
                 txn,
                 table="room_account_data",
@@ -243,14 +262,16 @@ class AccountDataWorkerStore(SQLBaseStore):
         if last_id == current_id:
             return []
 
-        def get_updated_global_account_data_txn(txn):
+        def get_updated_global_account_data_txn(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[int, str, str]]:
             sql = (
                 "SELECT stream_id, user_id, account_data_type"
                 " FROM account_data WHERE ? < stream_id AND stream_id <= ?"
                 " ORDER BY stream_id ASC LIMIT ?"
             )
             txn.execute(sql, (last_id, current_id, limit))
-            return txn.fetchall()
+            return cast(List[Tuple[int, str, str]], txn.fetchall())
 
         return await self.db_pool.runInteraction(
             "get_updated_global_account_data", get_updated_global_account_data_txn
@@ -273,14 +294,16 @@ class AccountDataWorkerStore(SQLBaseStore):
         if last_id == current_id:
             return []
 
-        def get_updated_room_account_data_txn(txn):
+        def get_updated_room_account_data_txn(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[int, str, str, str]]:
             sql = (
                 "SELECT stream_id, user_id, room_id, account_data_type"
                 " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
                 " ORDER BY stream_id ASC LIMIT ?"
             )
             txn.execute(sql, (last_id, current_id, limit))
-            return txn.fetchall()
+            return cast(List[Tuple[int, str, str, str]], txn.fetchall())
 
         return await self.db_pool.runInteraction(
             "get_updated_room_account_data", get_updated_room_account_data_txn
@@ -299,7 +322,9 @@ class AccountDataWorkerStore(SQLBaseStore):
             mapping from room_id string to per room account_data dicts.
         """
 
-        def get_updated_account_data_for_user_txn(txn):
+        def get_updated_account_data_for_user_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
             sql = (
                 "SELECT account_data_type, content FROM account_data"
                 " WHERE user_id = ? AND stream_id > ?"
@@ -316,7 +341,7 @@ class AccountDataWorkerStore(SQLBaseStore):
 
             txn.execute(sql, (user_id, stream_id))
 
-            account_data_by_room = {}
+            account_data_by_room: Dict[str, Dict[str, JsonDict]] = {}
             for row in txn:
                 room_account_data = account_data_by_room.setdefault(row[0], {})
                 room_account_data[row[1]] = db_to_json(row[2])
@@ -353,12 +378,15 @@ class AccountDataWorkerStore(SQLBaseStore):
             )
         )
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self,
+        stream_name: str,
+        instance_name: str,
+        token: int,
+        rows: Iterable[Any],
+    ) -> None:
         if stream_name == TagAccountDataStream.NAME:
             self._account_data_id_gen.advance(instance_name, token)
-            for row in rows:
-                self.get_tags_for_user.invalidate((row.user_id,))
-                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
         elif stream_name == AccountDataStream.NAME:
             self._account_data_id_gen.advance(instance_name, token)
             for row in rows:
@@ -372,7 +400,8 @@ class AccountDataWorkerStore(SQLBaseStore):
                     (row.user_id, row.room_id, row.data_type)
                 )
                 self._account_data_stream_cache.entity_has_changed(row.user_id, token)
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
+
+        super().process_replication_rows(stream_name, instance_name, token, rows)
 
     async def add_account_data_to_room(
         self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
@@ -389,6 +418,7 @@ class AccountDataWorkerStore(SQLBaseStore):
             The maximum stream ID.
         """
         assert self._can_write_to_account_data
+        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
 
         content_json = json_encoder.encode(content)
 
@@ -431,6 +461,7 @@ class AccountDataWorkerStore(SQLBaseStore):
             The maximum stream ID.
         """
         assert self._can_write_to_account_data
+        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
 
         async with self._account_data_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction(
@@ -452,7 +483,7 @@ class AccountDataWorkerStore(SQLBaseStore):
 
     def _add_account_data_for_user(
         self,
-        txn,
+        txn: LoggingTransaction,
         next_id: int,
         user_id: str,
         account_data_type: str,
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 4a883dc166..92c95a41d7 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -24,9 +24,8 @@ from synapse.appservice import (
 from synapse.config.appservice import load_appservices
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
-from synapse.storage.types import Connection
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 
@@ -58,7 +57,12 @@ def _make_exclusive_regex(
 
 
 class ApplicationServiceWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         self.services_cache = load_appservices(
             hs.hostname, hs.config.appservice.app_service_config_files
         )
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 36e8422fc6..0024348067 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -25,7 +25,7 @@ from synapse.replication.tcp.streams.events import (
     EventsStreamEventRow,
 )
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.engines import PostgresEngine
 from synapse.util.iterutils import batch_iter
 
@@ -41,7 +41,12 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
 
 
 class CacheInvalidationWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self._instance_name = hs.get_instance_name()
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index 0f56e10220..fd3fc298b3 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -18,7 +18,11 @@ from typing import TYPE_CHECKING, Optional
 from synapse.events.utils import prune_event_dict
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.util import json_encoder
@@ -31,7 +35,12 @@ logger = logging.getLogger(__name__)
 
 
 class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         if (
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 1dc7f0ebe3..8b0c614ece 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -26,7 +26,6 @@ from synapse.storage.database import (
     make_tuple_comparison_clause,
 )
 from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore
-from synapse.storage.types import Connection
 from synapse.types import JsonDict, UserID
 from synapse.util.caches.lrucache import LruCache
 
@@ -65,7 +64,12 @@ class LastConnectionInfo(TypedDict):
 
 
 class ClientIpBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
@@ -394,7 +398,12 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
 
 
 class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.user_ips_max_age = hs.config.server.user_ips_max_age
@@ -532,7 +541,12 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
 
 
 class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
 
         # (user_id, access_token, ip,) -> last_seen
         self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index ab8766c75b..b410eefdc7 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -601,7 +601,12 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
     REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox"
     REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox"
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index d5a4a661cd..3932599988 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -38,6 +38,7 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
+    LoggingDatabaseConnection,
     LoggingTransaction,
     make_tuple_comparison_clause,
 )
@@ -61,7 +62,12 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
 
 
 class DeviceWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.run_background_tasks:
@@ -101,7 +107,9 @@ class DeviceWorkerStore(SQLBaseStore):
             "count_devices_by_users", count_devices_by_users_txn, user_ids
         )
 
-    async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
+    async def get_device(
+        self, user_id: str, device_id: str
+    ) -> Optional[Dict[str, Any]]:
         """Retrieve a device. Only returns devices that are not marked as
         hidden.
 
@@ -109,15 +117,15 @@ class DeviceWorkerStore(SQLBaseStore):
             user_id: The ID of the user which owns the device
             device_id: The ID of the device to retrieve
         Returns:
-            A dict containing the device information
-        Raises:
-            StoreError: if the device is not found
+            A dict containing the device information, or `None` if the device does not
+            exist.
         """
         return await self.db_pool.simple_select_one(
             table="devices",
             keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
             retcols=("user_id", "device_id", "display_name"),
             desc="get_device",
+            allow_none=True,
         )
 
     async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]:
@@ -274,7 +282,9 @@ class DeviceWorkerStore(SQLBaseStore):
         # add the updated cross-signing keys to the results list
         for user_id, result in cross_signing_keys_by_user.items():
             result["user_id"] = user_id
-            # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+            results.append(("m.signing_key_update", result))
+            # also send the unstable version
+            # FIXME: remove this when enough servers have upgraded
             results.append(("org.matrix.signing_key_update", result))
 
         return now_stream_id, results
@@ -949,7 +959,12 @@ class DeviceWorkerStore(SQLBaseStore):
 
 
 class DeviceBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
@@ -1081,7 +1096,12 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
 
 
 class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         # Map of (user_id, device_id) -> bool. If there is an entry that implies
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index b06c1dc45b..57b5ffbad3 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -14,19 +14,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import abc
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    cast,
+)
 
 import attr
 from canonicaljson import encode_canonical_json
 
-from twisted.enterprise.adbapi import Connection
-
 from synapse.api.constants import DeviceKeyAlgorithms
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, make_in_list_sql_clause
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+    make_in_list_sql_clause,
+)
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
@@ -50,7 +63,12 @@ class DeviceKeyLookupResult:
 
 
 class EndToEndKeyBackgroundStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
@@ -62,8 +80,13 @@ class EndToEndKeyBackgroundStore(SQLBaseStore):
         )
 
 
-class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self._allow_device_name_lookup_over_federation = (
@@ -124,7 +147,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
         # Build the result structure, un-jsonify the results, and add the
         # "unsigned" section
-        rv = {}
+        rv: Dict[str, Dict[str, JsonDict]] = {}
         for user_id, device_keys in results.items():
             rv[user_id] = {}
             for device_id, device_info in device_keys.items():
@@ -195,6 +218,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
             # add each cross-signing signature to the correct device in the result dict.
             for (user_id, key_id, device_id, signature) in cross_sigs_result:
                 target_device_result = result[user_id][device_id]
+                # We've only looked up cross-signatures for non-deleted devices with key
+                # data.
+                assert target_device_result is not None
+                assert target_device_result.keys is not None
                 target_device_signatures = target_device_result.keys.setdefault(
                     "signatures", {}
                 )
@@ -207,7 +234,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
         return result
 
     def _get_e2e_device_keys_txn(
-        self, txn, query_list, include_all_devices=False, include_deleted_devices=False
+        self,
+        txn: LoggingTransaction,
+        query_list: Collection[Tuple[str, str]],
+        include_all_devices: bool = False,
+        include_deleted_devices: bool = False,
     ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
         """Get information on devices from the database
 
@@ -263,7 +294,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
         return result
 
     def _get_e2e_cross_signing_signatures_for_devices_txn(
-        self, txn: Cursor, device_query: Iterable[Tuple[str, str]]
+        self, txn: LoggingTransaction, device_query: Iterable[Tuple[str, str]]
     ) -> List[Tuple[str, str, str, str]]:
         """Get cross-signing signatures for a given list of devices
 
@@ -289,7 +320,17 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
         )
 
         txn.execute(signature_sql, signature_query_params)
-        return txn.fetchall()
+        return cast(
+            List[
+                Tuple[
+                    str,
+                    str,
+                    str,
+                    str,
+                ]
+            ],
+            txn.fetchall(),
+        )
 
     async def get_e2e_one_time_keys(
         self, user_id: str, device_id: str, key_ids: List[str]
@@ -335,7 +376,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
             new_keys: keys to add - each a tuple of (algorithm, key_id, key json)
         """
 
-        def _add_e2e_one_time_keys(txn):
+        def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
             set_tag("user_id", user_id)
             set_tag("device_id", device_id)
             set_tag("new_keys", new_keys)
@@ -375,7 +416,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
             A mapping from algorithm to number of keys for that algorithm.
         """
 
-        def _count_e2e_one_time_keys(txn):
+        def _count_e2e_one_time_keys(txn: LoggingTransaction) -> Dict[str, int]:
             sql = (
                 "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
                 " WHERE user_id = ? AND device_id = ?"
@@ -421,7 +462,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
         )
 
     def _set_e2e_fallback_keys_txn(
-        self, txn: Connection, user_id: str, device_id: str, fallback_keys: JsonDict
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        device_id: str,
+        fallback_keys: JsonDict,
     ) -> None:
         # fallback_keys will usually only have one item in it, so using a for
         # loop (as opposed to calling simple_upsert_many_txn) won't be too bad
@@ -483,7 +528,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
     async def get_e2e_cross_signing_key(
         self, user_id: str, key_type: str, from_user_id: Optional[str] = None
-    ) -> Optional[dict]:
+    ) -> Optional[JsonDict]:
         """Returns a user's cross-signing key.
 
         Args:
@@ -504,7 +549,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
         return user_keys.get(key_type)
 
     @cached(num_args=1)
-    def _get_bare_e2e_cross_signing_keys(self, user_id):
+    def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Dict[str, JsonDict]:
         """Dummy function.  Only used to make a cache for
         _get_bare_e2e_cross_signing_keys_bulk.
         """
@@ -517,7 +562,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
     )
     async def _get_bare_e2e_cross_signing_keys_bulk(
         self, user_ids: Iterable[str]
-    ) -> Dict[str, Dict[str, dict]]:
+    ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
         """Returns the cross-signing keys for a set of users.  The output of this
         function should be passed to _get_e2e_cross_signing_signatures_txn if
         the signatures for the calling user need to be fetched.
@@ -531,32 +576,35 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
             their user ID will map to None.
 
         """
-        return await self.db_pool.runInteraction(
+        result = await self.db_pool.runInteraction(
             "get_bare_e2e_cross_signing_keys_bulk",
             self._get_bare_e2e_cross_signing_keys_bulk_txn,
             user_ids,
         )
 
+        # The `Optional` comes from the `@cachedList` decorator.
+        return cast(Dict[str, Optional[Dict[str, JsonDict]]], result)
+
     def _get_bare_e2e_cross_signing_keys_bulk_txn(
         self,
-        txn: Connection,
+        txn: LoggingTransaction,
         user_ids: Iterable[str],
-    ) -> Dict[str, Dict[str, dict]]:
+    ) -> Dict[str, Dict[str, JsonDict]]:
         """Returns the cross-signing keys for a set of users.  The output of this
         function should be passed to _get_e2e_cross_signing_signatures_txn if
         the signatures for the calling user need to be fetched.
 
         Args:
-            txn (twisted.enterprise.adbapi.Connection): db connection
-            user_ids (list[str]): the users whose keys are being requested
+            txn: db connection
+            user_ids: the users whose keys are being requested
 
         Returns:
-            dict[str, dict[str, dict]]: mapping from user ID to key type to key
-                data.  If a user's cross-signing keys were not found, their user
-                ID will not be in the dict.
+            Mapping from user ID to key type to key data.
+            If a user's cross-signing keys were not found, their user ID will not be in
+            the dict.
 
         """
-        result = {}
+        result: Dict[str, Dict[str, JsonDict]] = {}
 
         for user_chunk in batch_iter(user_ids, 100):
             clause, params = make_in_list_sql_clause(
@@ -596,43 +644,48 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
                 user_id = row["user_id"]
                 key_type = row["keytype"]
                 key = db_to_json(row["keydata"])
-                user_info = result.setdefault(user_id, {})
-                user_info[key_type] = key
+                user_keys = result.setdefault(user_id, {})
+                user_keys[key_type] = key
 
         return result
 
     def _get_e2e_cross_signing_signatures_txn(
         self,
-        txn: Connection,
-        keys: Dict[str, Dict[str, dict]],
+        txn: LoggingTransaction,
+        keys: Dict[str, Optional[Dict[str, JsonDict]]],
         from_user_id: str,
-    ) -> Dict[str, Dict[str, dict]]:
+    ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
         """Returns the cross-signing signatures made by a user on a set of keys.
 
         Args:
-            txn (twisted.enterprise.adbapi.Connection): db connection
-            keys (dict[str, dict[str, dict]]): a map of user ID to key type to
-                key data.  This dict will be modified to add signatures.
-            from_user_id (str): fetch the signatures made by this user
+            txn: db connection
+            keys: a map of user ID to key type to key data.
+                This dict will be modified to add signatures.
+            from_user_id: fetch the signatures made by this user
 
         Returns:
-            dict[str, dict[str, dict]]: mapping from user ID to key type to key
-                data.  The return value will be the same as the keys argument,
-                with the modifications included.
+            Mapping from user ID to key type to key data.
+            The return value will be the same as the keys argument, with the
+            modifications included.
         """
 
         # find out what cross-signing keys (a.k.a. devices) we need to get
         # signatures for.  This is a map of (user_id, device_id) to key type
         # (device_id is the key's public part).
-        devices = {}
+        devices: Dict[Tuple[str, str], str] = {}
 
-        for user_id, user_info in keys.items():
-            if user_info is None:
+        for user_id, user_keys in keys.items():
+            if user_keys is None:
                 continue
-            for key_type, key in user_info.items():
+            for key_type, key in user_keys.items():
                 device_id = None
                 for k in key["keys"].values():
                     device_id = k
+                # `key` ought to be a `CrossSigningKey`, whose .keys property is a
+                # dictionary with a single entry:
+                #     "algorithm:base64_public_key": "base64_public_key"
+                # See https://spec.matrix.org/v1.1/client-server-api/#cross-signing
+                assert isinstance(device_id, str)
                 devices[(user_id, device_id)] = key_type
 
         for batch in batch_iter(devices.keys(), size=100):
@@ -656,15 +709,20 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
             # and add the signatures to the appropriate keys
             for row in rows:
-                key_id = row["key_id"]
-                target_user_id = row["target_user_id"]
-                target_device_id = row["target_device_id"]
+                key_id: str = row["key_id"]
+                target_user_id: str = row["target_user_id"]
+                target_device_id: str = row["target_device_id"]
                 key_type = devices[(target_user_id, target_device_id)]
                 # We need to copy everything, because the result may have come
                 # from the cache.  dict.copy only does a shallow copy, so we
                 # need to recursively copy the dicts that will be modified.
-                user_info = keys[target_user_id] = keys[target_user_id].copy()
-                target_user_key = user_info[key_type] = user_info[key_type].copy()
+                user_keys = keys[target_user_id]
+                # `user_keys` cannot be `None` because we only fetched signatures for
+                # users with keys
+                assert user_keys is not None
+                user_keys = keys[target_user_id] = user_keys.copy()
+
+                target_user_key = user_keys[key_type] = user_keys[key_type].copy()
                 if "signatures" in target_user_key:
                     signatures = target_user_key["signatures"] = target_user_key[
                         "signatures"
@@ -683,7 +741,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
     async def get_e2e_cross_signing_keys_bulk(
         self, user_ids: List[str], from_user_id: Optional[str] = None
-    ) -> Dict[str, Optional[Dict[str, dict]]]:
+    ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
         """Returns the cross-signing keys for a set of users.
 
         Args:
@@ -741,7 +799,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
         if last_id == current_id:
             return [], current_id, False
 
-        def _get_all_user_signature_changes_for_remotes_txn(txn):
+        def _get_all_user_signature_changes_for_remotes_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
             sql = """
                 SELECT stream_id, from_user_id AS user_id
                 FROM user_signature_stream
@@ -785,7 +845,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
         @trace
         def _claim_e2e_one_time_key_simple(
-            txn, user_id: str, device_id: str, algorithm: str
+            txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
         ) -> Optional[Tuple[str, str]]:
             """Claim OTK for device for DBs that don't support RETURNING.
 
@@ -825,7 +885,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
         @trace
         def _claim_e2e_one_time_key_returning(
-            txn, user_id: str, device_id: str, algorithm: str
+            txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
         ) -> Optional[Tuple[str, str]]:
             """Claim OTK for device for DBs that support RETURNING.
 
@@ -860,7 +920,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
             key_id, key_json = otk_row
             return f"{algorithm}:{key_id}", key_json
 
-        results = {}
+        results: Dict[str, Dict[str, Dict[str, str]]] = {}
         for user_id, device_id, algorithm in query_list:
             if self.database_engine.supports_returning:
                 # If we support RETURNING clause we can use a single query that
@@ -930,6 +990,18 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
 
 class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
+        super().__init__(database, db_conn, hs)
+
+        self._cross_signing_id_gen = StreamIdGenerator(
+            db_conn, "e2e_cross_signing_keys", "stream_id"
+        )
+
     async def set_e2e_device_keys(
         self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
     ) -> bool:
@@ -937,7 +1009,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
         or the keys were already in the database.
         """
 
-        def _set_e2e_device_keys_txn(txn):
+        def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool:
             set_tag("user_id", user_id)
             set_tag("device_id", device_id)
             set_tag("time_now", time_now)
@@ -973,7 +1045,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
         )
 
     async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
-        def delete_e2e_keys_by_device_txn(txn):
+        def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None:
             log_kv(
                 {
                     "message": "Deleting keys for device",
@@ -1012,17 +1084,24 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
         )
 
-    def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id):
+    def _set_e2e_cross_signing_key_txn(
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        key_type: str,
+        key: JsonDict,
+        stream_id: int,
+    ) -> None:
         """Set a user's cross-signing key.
 
         Args:
-            txn (twisted.enterprise.adbapi.Connection): db connection
-            user_id (str): the user to set the signing key for
-            key_type (str): the type of key that is being set: either 'master'
+            txn: db connection
+            user_id: the user to set the signing key for
+            key_type: the type of key that is being set: either 'master'
                 for a master key, 'self_signing' for a self-signing key, or
                 'user_signing' for a user-signing key
-            key (dict): the key data
-            stream_id (int)
+            key: the key data
+            stream_id
         """
         # the 'key' dict will look something like:
         # {
@@ -1075,13 +1154,15 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
         )
 
-    async def set_e2e_cross_signing_key(self, user_id, key_type, key):
+    async def set_e2e_cross_signing_key(
+        self, user_id: str, key_type: str, key: JsonDict
+    ) -> None:
         """Set a user's cross-signing key.
 
         Args:
-            user_id (str): the user to set the user-signing key for
-            key_type (str): the type of cross-signing key to set
-            key (dict): the key data
+            user_id: the user to set the user-signing key for
+            key_type: the type of cross-signing key to set
+            key: the key data
         """
 
         async with self._cross_signing_id_gen.get_next() as stream_id:
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 9580a40785..bc5ff25d08 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -24,7 +24,11 @@ from synapse.api.room_versions import EventFormatVersions, RoomVersion
 from synapse.events import EventBase, make_event_from_dict
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.signatures import SignatureWorkerStore
 from synapse.storage.engines import PostgresEngine
@@ -62,7 +66,12 @@ class _NoChainCoverIndex(Exception):
 
 
 class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.run_background_tasks:
@@ -1384,7 +1393,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         count = await self.db_pool.simple_select_one_onecol(
             table="federation_inbound_events_staging",
             keyvalues={"room_id": room_id},
-            retcol="COALESCE(COUNT(*), 0)",
+            retcol="COUNT(*)",
             desc="prune_staged_events_in_room_count",
         )
 
@@ -1476,9 +1485,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         """Update the prometheus metrics for the inbound federation staging area."""
 
         def _get_stats_for_federation_staging_txn(txn):
-            txn.execute(
-                "SELECT coalesce(count(*), 0) FROM federation_inbound_events_staging"
-            )
+            txn.execute("SELECT count(*) FROM federation_inbound_events_staging")
             (count,) = txn.fetchone()
 
             txn.execute(
@@ -1514,7 +1521,12 @@ class EventFederationStore(EventFederationWorkerStore):
 
     EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_update_handler(
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 3efdd0c920..eacff3e432 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -20,7 +20,11 @@ from typing_extensions import TypedDict
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 
@@ -82,7 +86,12 @@ def _deserialize_action(actions, is_highlight):
 
 
 class EventPushActionsWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         # These get correctly set by _find_stream_orderings_for_times_txn
@@ -910,7 +919,12 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 class EventPushActionsStore(EventPushActionsWorkerStore):
     EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 4e528612ea..81e67ece55 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -19,6 +19,7 @@ from collections import OrderedDict
 from typing import (
     TYPE_CHECKING,
     Any,
+    Collection,
     Dict,
     Generator,
     Iterable,
@@ -40,10 +41,13 @@ from synapse.events import EventBase  # noqa: F401
 from synapse.events.snapshot import EventContext  # noqa: F401
 from synapse.logging.utils import log_function
 from synapse.storage._base import db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.events_worker import EventCacheEntry
 from synapse.storage.databases.main.search import SearchEntry
-from synapse.storage.types import Connection
 from synapse.storage.util.id_generators import AbstractStreamIdGenerator
 from synapse.storage.util.sequence import SequenceGenerator
 from synapse.types import StateMap, get_domain_from_id
@@ -94,7 +98,7 @@ class PersistEventsStore:
         hs: "HomeServer",
         db: DatabasePool,
         main_data_store: "DataStore",
-        db_conn: Connection,
+        db_conn: LoggingDatabaseConnection,
     ):
         self.hs = hs
         self.db_pool = db
@@ -1319,14 +1323,13 @@ class PersistEventsStore:
 
         return [ec for ec in events_and_contexts if ec[0] not in to_remove]
 
-    def _store_event_txn(self, txn, events_and_contexts):
+    def _store_event_txn(
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: Collection[Tuple[EventBase, EventContext]],
+    ) -> None:
         """Insert new events into the event, event_json, redaction and
         state_events tables.
-
-        Args:
-            txn (twisted.enterprise.adbapi.Connection): db connection
-            events_and_contexts (list[(EventBase, EventContext)]): events
-                we are persisting
         """
 
         if not events_and_contexts:
@@ -1339,46 +1342,58 @@ class PersistEventsStore:
             d.pop("redacted_because", None)
             return d
 
-        self.db_pool.simple_insert_many_txn(
+        self.db_pool.simple_insert_many_values_txn(
             txn,
             table="event_json",
-            values=[
-                {
-                    "event_id": event.event_id,
-                    "room_id": event.room_id,
-                    "internal_metadata": json_encoder.encode(
-                        event.internal_metadata.get_dict()
-                    ),
-                    "json": json_encoder.encode(event_dict(event)),
-                    "format_version": event.format_version,
-                }
+            keys=("event_id", "room_id", "internal_metadata", "json", "format_version"),
+            values=(
+                (
+                    event.event_id,
+                    event.room_id,
+                    json_encoder.encode(event.internal_metadata.get_dict()),
+                    json_encoder.encode(event_dict(event)),
+                    event.format_version,
+                )
                 for event, _ in events_and_contexts
-            ],
+            ),
         )
 
-        self.db_pool.simple_insert_many_txn(
+        self.db_pool.simple_insert_many_values_txn(
             txn,
             table="events",
-            values=[
-                {
-                    "instance_name": self._instance_name,
-                    "stream_ordering": event.internal_metadata.stream_ordering,
-                    "topological_ordering": event.depth,
-                    "depth": event.depth,
-                    "event_id": event.event_id,
-                    "room_id": event.room_id,
-                    "type": event.type,
-                    "processed": True,
-                    "outlier": event.internal_metadata.is_outlier(),
-                    "origin_server_ts": int(event.origin_server_ts),
-                    "received_ts": self._clock.time_msec(),
-                    "sender": event.sender,
-                    "contains_url": (
-                        "url" in event.content and isinstance(event.content["url"], str)
-                    ),
-                }
+            keys=(
+                "instance_name",
+                "stream_ordering",
+                "topological_ordering",
+                "depth",
+                "event_id",
+                "room_id",
+                "type",
+                "processed",
+                "outlier",
+                "origin_server_ts",
+                "received_ts",
+                "sender",
+                "contains_url",
+            ),
+            values=(
+                (
+                    self._instance_name,
+                    event.internal_metadata.stream_ordering,
+                    event.depth,  # topological_ordering
+                    event.depth,  # depth
+                    event.event_id,
+                    event.room_id,
+                    event.type,
+                    True,  # processed
+                    event.internal_metadata.is_outlier(),
+                    int(event.origin_server_ts),
+                    self._clock.time_msec(),
+                    event.sender,
+                    "url" in event.content and isinstance(event.content["url"], str),
+                )
                 for event, _ in events_and_contexts
-            ],
+            ),
         )
 
         # If we're persisting an unredacted event we go and ensure
@@ -1397,27 +1412,15 @@ class PersistEventsStore:
         )
         txn.execute(sql + clause, [False] + args)
 
-        state_events_and_contexts = [
-            ec for ec in events_and_contexts if ec[0].is_state()
-        ]
-
-        state_values = []
-        for event, _ in state_events_and_contexts:
-            vals = {
-                "event_id": event.event_id,
-                "room_id": event.room_id,
-                "type": event.type,
-                "state_key": event.state_key,
-            }
-
-            # TODO: How does this work with backfilling?
-            if hasattr(event, "replaces_state"):
-                vals["prev_state"] = event.replaces_state
-
-            state_values.append(vals)
-
-        self.db_pool.simple_insert_many_txn(
-            txn, table="state_events", values=state_values
+        self.db_pool.simple_insert_many_values_txn(
+            txn,
+            table="state_events",
+            keys=("event_id", "room_id", "type", "state_key"),
+            values=(
+                (event.event_id, event.room_id, event.type, event.state_key)
+                for event, _ in events_and_contexts
+                if event.is_state()
+            ),
         )
 
     def _store_rejected_events_txn(self, txn, events_and_contexts):
@@ -1780,10 +1783,14 @@ class PersistEventsStore:
         )
 
         if rel_type == RelationTypes.REPLACE:
-            txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
+            txn.call_after(
+                self.store.get_applicable_edit.invalidate, (parent_id, event.room_id)
+            )
 
         if rel_type == RelationTypes.THREAD:
-            txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
+            txn.call_after(
+                self.store.get_thread_summary.invalidate, (parent_id, event.room_id)
+            )
 
     def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
         """Handles keeping track of insertion events and edges/connections.
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index c88fd35e7f..9b36941fec 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -23,6 +23,7 @@ from synapse.events import make_event_from_dict
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
+    LoggingDatabaseConnection,
     LoggingTransaction,
     make_tuple_comparison_clause,
 )
@@ -83,7 +84,12 @@ class _CalculateChainCover:
 
 
 class EventsBackgroundUpdatesStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_update_handler(
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index bb621df0dd..3f6086050b 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -19,8 +19,7 @@ from typing_extensions import TypedDict
 
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
-from synapse.storage.types import Connection
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 
@@ -40,7 +39,12 @@ class _RoomInGroup(TypedDict):
 
 
 class GroupServerWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         database.updates.register_background_index_update(
             update_name="local_group_updates_index",
             index_name="local_group_updates_stream_id_index",
diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
index a540f7fb26..bedacaf0d7 100644
--- a/synapse/storage/databases/main/lock.py
+++ b/synapse/storage/databases/main/lock.py
@@ -20,8 +20,11 @@ from twisted.internet.interfaces import IReactorCore
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingTransaction
-from synapse.storage.types import Connection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
@@ -54,7 +57,12 @@ class LockStore(SQLBaseStore):
     `last_renewed_ts` column with the current time.
     """
 
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self._reactor = hs.get_reactor()
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index d901933ae4..1480a0f048 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Dict
 from synapse.metrics import GaugeBucketCollector
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.event_push_actions import (
     EventPushActionsWorkerStore,
 )
@@ -55,7 +55,12 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
     stats and prometheus metrics.
     """
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         # Read the extrems every 60 minutes
@@ -100,7 +105,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
 
         def _count_messages(txn):
             sql = """
-                SELECT COALESCE(COUNT(*), 0) FROM events
+                SELECT COUNT(*) FROM events
                 WHERE type = 'm.room.encrypted'
                 AND stream_ordering > ?
             """
@@ -117,7 +122,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             like_clause = "%:" + self.hs.hostname
 
             sql = """
-                SELECT COALESCE(COUNT(*), 0) FROM events
+                SELECT COUNT(*) FROM events
                 WHERE type = 'm.room.encrypted'
                     AND sender LIKE ?
                 AND stream_ordering > ?
@@ -134,7 +139,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
     async def count_daily_active_e2ee_rooms(self):
         def _count(txn):
             sql = """
-                SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
+                SELECT COUNT(DISTINCT room_id) FROM events
                 WHERE type = 'm.room.encrypted'
                 AND stream_ordering > ?
             """
@@ -156,7 +161,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
 
         def _count_messages(txn):
             sql = """
-                SELECT COALESCE(COUNT(*), 0) FROM events
+                SELECT COUNT(*) FROM events
                 WHERE type = 'm.room.message'
                 AND stream_ordering > ?
             """
@@ -173,7 +178,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             like_clause = "%:" + self.hs.hostname
 
             sql = """
-                SELECT COALESCE(COUNT(*), 0) FROM events
+                SELECT COUNT(*) FROM events
                 WHERE type = 'm.room.message'
                     AND sender LIKE ?
                 AND stream_ordering > ?
@@ -190,7 +195,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
     async def count_daily_active_rooms(self):
         def _count(txn):
             sql = """
-                SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
+                SELECT COUNT(DISTINCT room_id) FROM events
                 WHERE type = 'm.room.message'
                 AND stream_ordering > ?
             """
@@ -226,7 +231,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
         Returns number of users seen in the past time_from period
         """
         sql = """
-            SELECT COALESCE(count(*), 0) FROM (
+            SELECT COUNT(*) FROM (
                 SELECT user_id FROM user_ips
                 WHERE last_seen > ?
                 GROUP BY user_id
@@ -253,7 +258,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             thirty_days_ago_in_secs = now - thirty_days_in_secs
 
             sql = """
-                SELECT platform, COALESCE(count(*), 0) FROM (
+                SELECT platform, COUNT(*) FROM (
                      SELECT
                         users.name, platform, users.creation_ts * 1000,
                         MAX(uip.last_seen)
@@ -291,7 +296,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
                 results[row[0]] = row[1]
 
             sql = """
-                SELECT COALESCE(count(*), 0) FROM (
+                SELECT COUNT(*) FROM (
                     SELECT users.name, users.creation_ts * 1000,
                                                         MAX(uip.last_seen)
                     FROM users
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index b5284e4f67..8f09dd8e87 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -16,8 +16,13 @@ from typing import TYPE_CHECKING, Dict, List, Optional
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, make_in_list_sql_clause
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    make_in_list_sql_clause,
+)
 from synapse.util.caches.descriptors import cached
+from synapse.util.threepids import canonicalise_email
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -30,7 +35,12 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000
 
 
 class MonthlyActiveUsersWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
         self._clock = hs.get_clock()
         self.hs = hs
@@ -49,7 +59,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
         def _count_users(txn):
             # Exclude app service users
             sql = """
-                SELECT COALESCE(count(*), 0)
+                SELECT COUNT(*)
                 FROM monthly_active_users
                     LEFT JOIN users
                     ON monthly_active_users.user_id=users.name
@@ -76,7 +86,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
 
         def _count_users_by_service(txn):
             sql = """
-                SELECT COALESCE(appservice_id, 'native'), COALESCE(count(*), 0)
+                SELECT COALESCE(appservice_id, 'native'), COUNT(*)
                 FROM monthly_active_users
                 LEFT JOIN users ON monthly_active_users.user_id=users.name
                 GROUP BY appservice_id;
@@ -103,7 +113,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
             : self.hs.config.server.max_mau_value
         ]:
             user_id = await self.hs.get_datastore().get_user_id_by_threepid(
-                tp["medium"], tp["address"]
+                tp["medium"], canonicalise_email(tp["address"])
             )
             if user_id:
                 users.append(user_id)
@@ -212,7 +222,12 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
 
 
 class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self._mau_stats_only = hs.config.server.mau_stats_only
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index cc0eebdb46..cbf9ec38f7 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
 from synapse.api.presence import PresenceState, UserPresenceState
 from synapse.replication.tcp.streams import PresenceStream
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.types import Connection
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
@@ -33,7 +33,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore):
     def __init__(
         self,
         database: DatabasePool,
-        db_conn: Connection,
+        db_conn: LoggingDatabaseConnection,
         hs: "HomeServer",
     ):
         super().__init__(database, db_conn, hs)
@@ -52,7 +52,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
     def __init__(
         self,
         database: DatabasePool,
-        db_conn: Connection,
+        db_conn: LoggingDatabaseConnection,
         hs: "HomeServer",
     ):
         super().__init__(database, db_conn, hs)
@@ -269,6 +269,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
         """
         # Add user entries to the table, updating the presence_stream_id column if the user already
         # exists in the table.
+        presence_stream_id = self._presence_id_gen.get_current_token()
         await self.db_pool.simple_upsert_many(
             table="users_to_send_full_presence_to",
             key_names=("user_id",),
@@ -279,9 +280,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
             # devices at different times, each device will receive full presence once - when
             # the presence stream ID in their sync token is less than the one in the table
             # for their user ID.
-            value_values=(
-                (self._presence_id_gen.get_current_token(),) for _ in user_ids
-            ),
+            value_values=[(presence_stream_id,) for _ in user_ids],
             desc="add_users_to_send_full_presence_to",
         )
 
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 3b63267395..e01c94930a 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -20,7 +20,7 @@ from synapse.api.errors import NotFoundError, StoreError
 from synapse.push.baserules import list_with_base_rules
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.pusher import PusherWorkerStore
@@ -81,7 +81,12 @@ class PushRulesWorkerStore(
     `get_max_push_rules_stream_id` which can be called in the initializer.
     """
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.worker_app is None:
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index c99f8aebdb..bf0b903af2 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -14,14 +14,29 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+)
 
 from twisted.internet import defer
 
+from synapse.api.constants import ReceiptTypes
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import ReceiptsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.types import JsonDict
@@ -36,7 +51,12 @@ logger = logging.getLogger(__name__)
 
 
 class ReceiptsWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         self._instance_name = hs.get_instance_name()
 
         if isinstance(database.engine, PostgresEngine):
@@ -78,17 +98,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
             "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
         )
 
-    def get_max_receipt_stream_id(self):
-        """Get the current max stream ID for receipts stream
-
-        Returns:
-            int
-        """
+    def get_max_receipt_stream_id(self) -> int:
+        """Get the current max stream ID for receipts stream"""
         return self._receipts_id_gen.get_current_token()
 
     @cached()
-    async def get_users_with_read_receipts_in_room(self, room_id):
-        receipts = await self.get_receipts_for_room(room_id, "m.read")
+    async def get_users_with_read_receipts_in_room(self, room_id: str) -> Set[str]:
+        receipts = await self.get_receipts_for_room(room_id, ReceiptTypes.READ)
         return {r["user_id"] for r in receipts}
 
     @cached(num_args=2)
@@ -119,7 +135,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
         )
 
     @cached(num_args=2)
-    async def get_receipts_for_user(self, user_id, receipt_type):
+    async def get_receipts_for_user(
+        self, user_id: str, receipt_type: str
+    ) -> Dict[str, str]:
         rows = await self.db_pool.simple_select_list(
             table="receipts_linearized",
             keyvalues={"user_id": user_id, "receipt_type": receipt_type},
@@ -129,8 +147,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
         return {row["room_id"]: row["event_id"] for row in rows}
 
-    async def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
-        def f(txn):
+    async def get_receipts_for_user_with_orderings(
+        self, user_id: str, receipt_type: str
+    ) -> JsonDict:
+        def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]:
             sql = (
                 "SELECT rl.room_id, rl.event_id,"
                 " e.topological_ordering, e.stream_ordering"
@@ -209,10 +229,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
     @cached(num_args=3, tree=True)
     async def _get_linearized_receipts_for_room(
         self, room_id: str, to_key: int, from_key: Optional[int] = None
-    ) -> List[dict]:
+    ) -> List[JsonDict]:
         """See get_linearized_receipts_for_room"""
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
             if from_key:
                 sql = (
                     "SELECT * FROM receipts_linearized WHERE"
@@ -250,11 +270,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
         list_name="room_ids",
         num_args=3,
     )
-    async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
+    async def _get_linearized_receipts_for_rooms(
+        self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
+    ) -> Dict[str, List[JsonDict]]:
         if not room_ids:
             return {}
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
             if from_key:
                 sql = """
                     SELECT * FROM receipts_linearized WHERE
@@ -323,7 +345,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             A dictionary of roomids to a list of receipts.
         """
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
             if from_key:
                 sql = """
                     SELECT * FROM receipts_linearized WHERE
@@ -379,7 +401,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         if last_id == current_id:
             return defer.succeed([])
 
-        def _get_users_sent_receipts_between_txn(txn):
+        def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]:
             sql = """
                 SELECT DISTINCT user_id FROM receipts_linearized
                 WHERE ? < stream_id AND stream_id <= ?
@@ -419,7 +441,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
         if last_id == current_id:
             return [], current_id, False
 
-        def get_all_updated_receipts_txn(txn):
+        def get_all_updated_receipts_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, list]], int, bool]:
             sql = """
                 SELECT stream_id, room_id, receipt_type, user_id, event_id, data
                 FROM receipts_linearized
@@ -446,8 +470,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
     def _invalidate_get_users_with_receipts_in_room(
         self, room_id: str, receipt_type: str, user_id: str
-    ):
-        if receipt_type != "m.read":
+    ) -> None:
+        if receipt_type != ReceiptTypes.READ:
             return
 
         res = self.get_users_with_read_receipts_in_room.cache.get_immediate(
@@ -461,7 +485,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
         self.get_users_with_read_receipts_in_room.invalidate((room_id,))
 
-    def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
+    def invalidate_caches_for_receipt(
+        self, room_id: str, receipt_type: str, user_id: str
+    ) -> None:
         self.get_receipts_for_user.invalidate((user_id, receipt_type))
         self._get_linearized_receipts_for_room.invalidate((room_id,))
         self.get_last_receipt_event_id_for_user.invalidate(
@@ -482,11 +508,18 @@ class ReceiptsWorkerStore(SQLBaseStore):
         return super().process_replication_rows(stream_name, instance_name, token, rows)
 
     def insert_linearized_receipt_txn(
-        self, txn, room_id, receipt_type, user_id, event_id, data, stream_id
-    ):
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        receipt_type: str,
+        user_id: str,
+        event_id: str,
+        data: JsonDict,
+        stream_id: int,
+    ) -> Optional[int]:
         """Inserts a read-receipt into the database if it's newer than the current RR
 
-        Returns: int|None
+        Returns:
             None if the RR is older than the current RR
             otherwise, the rx timestamp of the event that the RR corresponds to
                 (or 0 if the event is unknown)
@@ -550,7 +583,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             lock=False,
         )
 
-        if receipt_type == "m.read" and stream_ordering is not None:
+        if receipt_type == ReceiptTypes.READ and stream_ordering is not None:
             self._remove_old_push_actions_before_txn(
                 txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
             )
@@ -580,7 +613,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         else:
             # we need to points in graph -> linearized form.
             # TODO: Make this better.
-            def graph_to_linear(txn):
+            def graph_to_linear(txn: LoggingTransaction) -> str:
                 clause, args = make_in_list_sql_clause(
                     self.database_engine, "event_id", event_ids
                 )
@@ -634,11 +667,16 @@ class ReceiptsWorkerStore(SQLBaseStore):
         return stream_id, max_persisted_id
 
     async def insert_graph_receipt(
-        self, room_id, receipt_type, user_id, event_ids, data
-    ):
+        self,
+        room_id: str,
+        receipt_type: str,
+        user_id: str,
+        event_ids: List[str],
+        data: JsonDict,
+    ) -> None:
         assert self._can_write_to_receipts
 
-        return await self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "insert_graph_receipt",
             self.insert_graph_receipt_txn,
             room_id,
@@ -649,8 +687,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
         )
 
     def insert_graph_receipt_txn(
-        self, txn, room_id, receipt_type, user_id, event_ids, data
-    ):
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        receipt_type: str,
+        user_id: str,
+        event_ids: List[str],
+        data: JsonDict,
+    ) -> None:
         assert self._can_write_to_receipts
 
         txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index e1ddf06916..29d9d4de96 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -794,7 +794,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             yesterday = int(self._clock.time()) - (60 * 60 * 24)
 
             sql = """
-                SELECT user_type, COALESCE(count(*), 0) AS count FROM (
+                SELECT user_type, COUNT(*) AS count FROM (
                     SELECT
                     CASE
                         WHEN is_guest=0 AND appservice_id IS NULL THEN 'native'
@@ -819,7 +819,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         def _count_users(txn):
             txn.execute(
                 """
-                SELECT COALESCE(COUNT(*), 0) FROM users
+                SELECT COUNT(*) FROM users
                 WHERE appservice_id IS NULL
             """
             )
@@ -856,7 +856,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
         Args:
             medium: threepid medium e.g. email
-            address: threepid address e.g. me@example.com
+            address: threepid address e.g. me@example.com. This must already be
+                in canonical form.
 
         Returns:
             The user ID or None if no user id/threepid mapping exists
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 0a43acda07..729ff17e2e 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -37,6 +37,7 @@ class RelationsWorkerStore(SQLBaseStore):
     async def get_relations_for_event(
         self,
         event_id: str,
+        room_id: str,
         relation_type: Optional[str] = None,
         event_type: Optional[str] = None,
         aggregation_key: Optional[str] = None,
@@ -49,6 +50,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
         Args:
             event_id: Fetch events that relate to this event ID.
+            room_id: The room the event belongs to.
             relation_type: Only fetch events with this relation type, if given.
             event_type: Only fetch events with this event type, if given.
             aggregation_key: Only fetch events with this aggregation key, if given.
@@ -63,8 +65,8 @@ class RelationsWorkerStore(SQLBaseStore):
             the form `{"event_id": "..."}`.
         """
 
-        where_clause = ["relates_to_id = ?"]
-        where_args: List[Union[str, int]] = [event_id]
+        where_clause = ["relates_to_id = ?", "room_id = ?"]
+        where_args: List[Union[str, int]] = [event_id, room_id]
 
         if relation_type is not None:
             where_clause.append("relation_type = ?")
@@ -199,6 +201,7 @@ class RelationsWorkerStore(SQLBaseStore):
     async def get_aggregation_groups_for_event(
         self,
         event_id: str,
+        room_id: str,
         event_type: Optional[str] = None,
         limit: int = 5,
         direction: str = "b",
@@ -213,6 +216,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
         Args:
             event_id: Fetch events that relate to this event ID.
+            room_id: The room the event belongs to.
             event_type: Only fetch events with this event type, if given.
             limit: Only fetch the `limit` groups.
             direction: Whether to fetch the highest count first (`"b"`) or
@@ -225,8 +229,12 @@ class RelationsWorkerStore(SQLBaseStore):
             `type`, `key` and `count` fields.
         """
 
-        where_clause = ["relates_to_id = ?", "relation_type = ?"]
-        where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION]
+        where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"]
+        where_args: List[Union[str, int]] = [
+            event_id,
+            room_id,
+            RelationTypes.ANNOTATION,
+        ]
 
         if event_type:
             where_clause.append("type = ?")
@@ -288,7 +296,9 @@ class RelationsWorkerStore(SQLBaseStore):
         )
 
     @cached()
-    async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
+    async def get_applicable_edit(
+        self, event_id: str, room_id: str
+    ) -> Optional[EventBase]:
         """Get the most recent edit (if any) that has happened for the given
         event.
 
@@ -296,6 +306,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
         Args:
             event_id: The original event ID
+            room_id: The original event's room ID
 
         Returns:
             The most recent edit, if any.
@@ -317,13 +328,14 @@ class RelationsWorkerStore(SQLBaseStore):
             WHERE
                 relates_to_id = ?
                 AND relation_type = ?
+                AND edit.room_id = ?
                 AND edit.type = 'm.room.message'
             ORDER by edit.origin_server_ts DESC, edit.event_id DESC
             LIMIT 1
         """
 
         def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
-            txn.execute(sql, (event_id, RelationTypes.REPLACE))
+            txn.execute(sql, (event_id, RelationTypes.REPLACE, room_id))
             row = txn.fetchone()
             if row:
                 return row[0]
@@ -340,13 +352,14 @@ class RelationsWorkerStore(SQLBaseStore):
 
     @cached()
     async def get_thread_summary(
-        self, event_id: str
+        self, event_id: str, room_id: str
     ) -> Tuple[int, Optional[EventBase]]:
         """Get the number of threaded replies, the senders of those replies, and
         the latest reply (if any) for the given event.
 
         Args:
-            event_id: The original event ID
+            event_id: Summarize the thread related to this event ID.
+            room_id: The room the event belongs to.
 
         Returns:
             The number of items in the thread and the most recent response, if any.
@@ -363,12 +376,13 @@ class RelationsWorkerStore(SQLBaseStore):
                 INNER JOIN events USING (event_id)
                 WHERE
                     relates_to_id = ?
+                    AND room_id = ?
                     AND relation_type = ?
                 ORDER BY topological_ordering DESC, stream_ordering DESC
                 LIMIT 1
             """
 
-            txn.execute(sql, (event_id, RelationTypes.THREAD))
+            txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
             row = txn.fetchone()
             if row is None:
                 return 0, None
@@ -376,13 +390,15 @@ class RelationsWorkerStore(SQLBaseStore):
             latest_event_id = row[0]
 
             sql = """
-                SELECT COALESCE(COUNT(event_id), 0)
+                SELECT COUNT(event_id)
                 FROM event_relations
+                INNER JOIN events USING (event_id)
                 WHERE
                     relates_to_id = ?
+                    AND room_id = ?
                     AND relation_type = ?
             """
-            txn.execute(sql, (event_id, RelationTypes.THREAD))
+            txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
             count = txn.fetchone()[0]  # type: ignore[index]
 
             return count, latest_event_id
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 7d694d852d..6cf6cc8484 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -24,7 +24,11 @@ from synapse.api.errors import StoreError
 from synapse.api.room_versions import RoomVersion, RoomVersions
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.search import SearchStore
 from synapse.storage.types import Cursor
 from synapse.types import JsonDict, ThirdPartyInstanceID
@@ -72,7 +76,12 @@ class RoomSortOrder(Enum):
 
 
 class RoomWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.config = hs.config
@@ -208,7 +217,7 @@ class RoomWorkerStore(SQLBaseStore):
 
             sql = """
                 SELECT
-                    COALESCE(COUNT(*), 0)
+                    COUNT(*)
                 FROM (
                     %(published_sql)s
                 ) published
@@ -1050,7 +1059,12 @@ _REPLACE_ROOM_DEPTH_SQL_COMMANDS = (
 
 
 class RoomBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.config = hs.config
@@ -1435,7 +1449,12 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
 
 
 class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.config = hs.config
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 6b2a8d06a6..cda80d6511 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -37,7 +37,7 @@ from synapse.metrics.background_process_metrics import (
     wrap_as_background_process,
 )
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import Sqlite3Engine
 from synapse.storage.roommember import (
@@ -64,7 +64,12 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
 
 
 class RoomMemberWorkerStore(EventsWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         # Used by `_get_joined_hosts` to ensure only one thing mutates the cache
@@ -985,7 +990,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
 
 class RoomMemberBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
         self.db_pool.updates.register_background_update_handler(
             _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
@@ -1135,7 +1145,12 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
 
 
 class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
     async def forget(self, user_id: str, room_id: str) -> None:
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 7fe233767f..f87acfb866 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -20,7 +20,11 @@ from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
@@ -105,7 +109,12 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
     EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
     EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         if not hs.config.server.enable_search:
@@ -358,7 +367,12 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
 
 class SearchStore(SearchBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
     async def search_msgs(self, room_ids, search_term, keys):
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index fa2c3b1feb..4bc044fb16 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -22,7 +22,11 @@ from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.state import StateFilter
@@ -56,7 +60,12 @@ class _GetStateGroupDelta(
 class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
     """The parts of StateGroupStore that can be called from workers."""
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
     async def get_room_version(self, room_id: str) -> RoomVersion:
@@ -349,7 +358,12 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
     EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
     DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
@@ -536,5 +550,10 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
       * `state_groups_state`: Maps state group to state events.
     """
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 7f3624b128..188afec332 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -56,7 +56,9 @@ class StateDeltasStore(SQLBaseStore):
         prev_stream_id = int(prev_stream_id)
 
         # check we're not going backwards
-        assert prev_stream_id <= max_stream_id
+        assert (
+            prev_stream_id <= max_stream_id
+        ), f"New stream id {max_stream_id} is smaller than prev stream id {prev_stream_id}"
 
         if not self._curr_state_delta_stream_cache.has_any_entity_changed(
             prev_stream_id
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 5d7b59d861..a0472e37f5 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -24,7 +24,7 @@ from twisted.internet.defer import DeferredLock
 
 from synapse.api.constants import EventContentFields, EventTypes, Membership
 from synapse.api.errors import StoreError
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.state_deltas import StateDeltasStore
 from synapse.types import JsonDict
 from synapse.util.caches.descriptors import cached
@@ -96,7 +96,12 @@ class UserSortOrder(Enum):
 
 
 class StatsStore(StateDeltasStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
@@ -533,7 +538,7 @@ class StatsStore(StateDeltasStore):
 
             txn.execute(
                 """
-                    SELECT COALESCE(count(*), 0) FROM current_state_events
+                    SELECT COUNT(*) FROM current_state_events
                     WHERE room_id = ?
                 """,
                 (room_id,),
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 57aab55259..9488fd5094 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -49,6 +49,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
     DatabasePool,
+    LoggingDatabaseConnection,
     LoggingTransaction,
     make_in_list_sql_clause,
 )
@@ -339,7 +340,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
     which can be called in the initializer.
     """
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self._instance_name = hs.get_instance_name()
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 8f510de53d..c8e508a910 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -15,11 +15,13 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, List, Tuple, cast
+from typing import Any, Dict, Iterable, List, Tuple, cast
 
+from synapse.replication.tcp.streams import TagAccountDataStream
 from synapse.storage._base import db_to_json
 from synapse.storage.database import LoggingTransaction
 from synapse.storage.databases.main.account_data import AccountDataWorkerStore
+from synapse.storage.util.id_generators import AbstractStreamIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
@@ -204,6 +206,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
             The next account data ID.
         """
         assert self._can_write_to_account_data
+        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
 
         content_json = json_encoder.encode(content)
 
@@ -230,6 +233,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
             The next account data ID.
         """
         assert self._can_write_to_account_data
+        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
 
         def remove_tag_txn(txn: LoggingTransaction, next_id: int) -> None:
             sql = (
@@ -258,6 +262,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
             next_id: The the revision to advance to.
         """
         assert self._can_write_to_account_data
+        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
 
         txn.call_after(
             self._account_data_stream_cache.entity_has_changed, user_id, next_id
@@ -287,6 +292,21 @@ class TagsWorkerStore(AccountDataWorkerStore):
                 # than the id that the client has.
                 pass
 
+    def process_replication_rows(
+        self,
+        stream_name: str,
+        instance_name: str,
+        token: int,
+        rows: Iterable[Any],
+    ) -> None:
+        if stream_name == TagAccountDataStream.NAME:
+            self._account_data_id_gen.advance(instance_name, token)
+            for row in rows:
+                self.get_tags_for_user.invalidate((row.user_id,))
+                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
+
+        super().process_replication_rows(stream_name, instance_name, token, rows)
+
 
 class TagsStore(TagsWorkerStore):
     pass
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 1622822552..54b41513ee 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -22,7 +22,11 @@ from canonicaljson import encode_canonical_json
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import db_to_json
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.types import JsonDict
 from synapse.util.caches.descriptors import cached
@@ -71,7 +75,12 @@ class DestinationRetryTimings:
 
 
 class TransactionWorkerStore(CacheInvalidationWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.run_background_tasks:
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index e98a45b6af..0f9b8575d3 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -32,11 +32,14 @@ if TYPE_CHECKING:
     from synapse.server import HomeServer
 
 from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.state import StateFilter
 from synapse.storage.databases.main.state_deltas import StateDeltasStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from synapse.storage.types import Connection
 from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id
 from synapse.util.caches.descriptors import cached
 
@@ -53,7 +56,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
     def __init__(
         self,
         database: DatabasePool,
-        db_conn: Connection,
+        db_conn: LoggingDatabaseConnection,
         hs: "HomeServer",
     ):
         super().__init__(database, db_conn, hs)
@@ -592,7 +595,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
     def __init__(
         self,
         database: DatabasePool,
-        db_conn: Connection,
+        db_conn: LoggingDatabaseConnection,
         hs: "HomeServer",
     ) -> None:
         super().__init__(database, db_conn, hs)
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 50d08094d5..2a3d47185a 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-SCHEMA_VERSION = 66  # remember to update the list below when updating
+SCHEMA_VERSION = 67  # remember to update the list below when updating
 """Represents the expectations made by the codebase about the database schema
 
 This should be incremented whenever the codebase changes its requirements on the
@@ -50,6 +50,9 @@ Changes in SCHEMA_VERSION = 65:
 Changes in SCHEMA_VERSION = 66:
     - Queries on state_key columns are now disambiguated (ie, the codebase can handle
       the `events` table having a `state_key` column).
+
+Changes in SCHEMA_VERSION = 67:
+    - state_events.prev_state is no longer written to.
 """
 
 
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 4ff3013908..b8112e1c05 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -74,8 +74,6 @@ class IdGenerator:
 def _load_current_id(
     db_conn: LoggingDatabaseConnection, table: str, column: str, step: int = 1
 ) -> int:
-    # debug logging for https://github.com/matrix-org/synapse/issues/7968
-    logger.info("initialising stream generator for %s(%s)", table, column)
     cur = db_conn.cursor(txn_name="_load_current_id")
     if step == 1:
         cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
@@ -86,7 +84,9 @@ def _load_current_id(
     (val,) = result
     cur.close()
     current_id = int(val) if val else step
-    return (max if step > 0 else min)(current_id, step)
+    res = (max if step > 0 else min)(current_id, step)
+    logger.info("Initialising stream generator for %s(%s): %i", table, column, res)
+    return res
 
 
 class AbstractStreamIdTracker(metaclass=abc.ABCMeta):