Diffstat (limited to 'synapse')
-rw-r--r--  synapse/api/constants.py | 198
-rw-r--r--  synapse/app/generic_worker.py | 2
-rw-r--r--  synapse/app/homeserver.py | 1
-rw-r--r--  synapse/config/registration.py | 31
-rw-r--r--  synapse/crypto/keyring.py | 30
-rw-r--r--  synapse/events/snapshot.py | 5
-rw-r--r--  synapse/federation/federation_client.py | 31
-rw-r--r--  synapse/federation/federation_server.py | 5
-rw-r--r--  synapse/federation/transport/client.py | 35
-rw-r--r--  synapse/federation/transport/server/federation.py | 6
-rw-r--r--  synapse/handlers/auth.py | 90
-rw-r--r--  synapse/handlers/register.py | 49
-rw-r--r--  synapse/handlers/room_summary.py | 14
-rw-r--r--  synapse/module_api/__init__.py | 139
-rw-r--r--  synapse/push/emailpusher.py | 10
-rw-r--r--  synapse/push/httppusher.py | 3
-rw-r--r--  synapse/push/mailer.py | 72
-rw-r--r--  synapse/push/push_types.py | 136
-rw-r--r--  synapse/python_dependencies.py | 2
-rw-r--r--  synapse/replication/slave/storage/_slaved_id_tracker.py | 22
-rw-r--r--  synapse/replication/slave/storage/push_rule.py | 4
-rw-r--r--  synapse/replication/tcp/streams/events.py | 6
-rw-r--r--  synapse/rest/admin/__init__.py | 19
-rw-r--r--  synapse/rest/admin/_base.py | 3
-rw-r--r--  synapse/rest/admin/devices.py | 21
-rw-r--r--  synapse/rest/admin/event_reports.py | 21
-rw-r--r--  synapse/rest/admin/groups.py | 5
-rw-r--r--  synapse/rest/admin/media.py | 53
-rw-r--r--  synapse/rest/admin/registration_tokens.py | 51
-rw-r--r--  synapse/rest/admin/rooms.py | 68
-rw-r--r--  synapse/rest/admin/server_notice_servlet.py | 11
-rw-r--r--  synapse/rest/admin/statistics.py | 21
-rw-r--r--  synapse/rest/admin/users.py | 173
-rw-r--r--  synapse/rest/client/login.py | 64
-rw-r--r--  synapse/rest/client/register.py | 9
-rw-r--r--  synapse/rest/client/room.py | 8
-rw-r--r--  synapse/state/__init__.py | 2
-rw-r--r--  synapse/state/v1.py | 3
-rw-r--r--  synapse/storage/_base.py | 4
-rw-r--r--  synapse/storage/background_updates.py | 192
-rw-r--r--  synapse/storage/databases/main/event_push_actions.py | 19
-rw-r--r--  synapse/storage/databases/main/events.py | 103
-rw-r--r--  synapse/storage/databases/main/events_worker.py | 370
-rw-r--r--  synapse/storage/databases/main/push_rule.py | 11
-rw-r--r--  synapse/storage/databases/main/registration.py | 28
-rw-r--r--  synapse/storage/persist_events.py | 3
-rw-r--r--  synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql | 28
-rw-r--r--  synapse/storage/util/id_generators.py | 116
48 files changed, 1612 insertions, 685 deletions
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index a33ac34161..f7d29b4319 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -17,6 +17,8 @@
 
 """Contains constants from the specification."""
 
+from typing_extensions import Final
+
 # the max size of a (canonical-json-encoded) event
 MAX_PDU_SIZE = 65536
 
@@ -39,125 +41,125 @@ class Membership:
 
     """Represents the membership states of a user in a room."""
 
-    INVITE = "invite"
-    JOIN = "join"
-    KNOCK = "knock"
-    LEAVE = "leave"
-    BAN = "ban"
-    LIST = (INVITE, JOIN, KNOCK, LEAVE, BAN)
+    INVITE: Final = "invite"
+    JOIN: Final = "join"
+    KNOCK: Final = "knock"
+    LEAVE: Final = "leave"
+    BAN: Final = "ban"
+    LIST: Final = (INVITE, JOIN, KNOCK, LEAVE, BAN)
 
 
 class PresenceState:
     """Represents the presence state of a user."""
 
-    OFFLINE = "offline"
-    UNAVAILABLE = "unavailable"
-    ONLINE = "online"
-    BUSY = "org.matrix.msc3026.busy"
+    OFFLINE: Final = "offline"
+    UNAVAILABLE: Final = "unavailable"
+    ONLINE: Final = "online"
+    BUSY: Final = "org.matrix.msc3026.busy"
 
 
 class JoinRules:
-    PUBLIC = "public"
-    KNOCK = "knock"
-    INVITE = "invite"
-    PRIVATE = "private"
+    PUBLIC: Final = "public"
+    KNOCK: Final = "knock"
+    INVITE: Final = "invite"
+    PRIVATE: Final = "private"
     # As defined for MSC3083.
-    RESTRICTED = "restricted"
+    RESTRICTED: Final = "restricted"
 
 
 class RestrictedJoinRuleTypes:
     """Understood types for the allow rules in restricted join rules."""
 
-    ROOM_MEMBERSHIP = "m.room_membership"
+    ROOM_MEMBERSHIP: Final = "m.room_membership"
 
 
 class LoginType:
-    PASSWORD = "m.login.password"
-    EMAIL_IDENTITY = "m.login.email.identity"
-    MSISDN = "m.login.msisdn"
-    RECAPTCHA = "m.login.recaptcha"
-    TERMS = "m.login.terms"
-    SSO = "m.login.sso"
-    DUMMY = "m.login.dummy"
-    REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token"
+    PASSWORD: Final = "m.login.password"
+    EMAIL_IDENTITY: Final = "m.login.email.identity"
+    MSISDN: Final = "m.login.msisdn"
+    RECAPTCHA: Final = "m.login.recaptcha"
+    TERMS: Final = "m.login.terms"
+    SSO: Final = "m.login.sso"
+    DUMMY: Final = "m.login.dummy"
+    REGISTRATION_TOKEN: Final = "org.matrix.msc3231.login.registration_token"
 
 
 # This is used in the `type` parameter for /register when called by
 # an appservice to register a new user.
-APP_SERVICE_REGISTRATION_TYPE = "m.login.application_service"
+APP_SERVICE_REGISTRATION_TYPE: Final = "m.login.application_service"
 
 
 class EventTypes:
-    Member = "m.room.member"
-    Create = "m.room.create"
-    Tombstone = "m.room.tombstone"
-    JoinRules = "m.room.join_rules"
-    PowerLevels = "m.room.power_levels"
-    Aliases = "m.room.aliases"
-    Redaction = "m.room.redaction"
-    ThirdPartyInvite = "m.room.third_party_invite"
-    RelatedGroups = "m.room.related_groups"
-
-    RoomHistoryVisibility = "m.room.history_visibility"
-    CanonicalAlias = "m.room.canonical_alias"
-    Encrypted = "m.room.encrypted"
-    RoomAvatar = "m.room.avatar"
-    RoomEncryption = "m.room.encryption"
-    GuestAccess = "m.room.guest_access"
+    Member: Final = "m.room.member"
+    Create: Final = "m.room.create"
+    Tombstone: Final = "m.room.tombstone"
+    JoinRules: Final = "m.room.join_rules"
+    PowerLevels: Final = "m.room.power_levels"
+    Aliases: Final = "m.room.aliases"
+    Redaction: Final = "m.room.redaction"
+    ThirdPartyInvite: Final = "m.room.third_party_invite"
+    RelatedGroups: Final = "m.room.related_groups"
+
+    RoomHistoryVisibility: Final = "m.room.history_visibility"
+    CanonicalAlias: Final = "m.room.canonical_alias"
+    Encrypted: Final = "m.room.encrypted"
+    RoomAvatar: Final = "m.room.avatar"
+    RoomEncryption: Final = "m.room.encryption"
+    GuestAccess: Final = "m.room.guest_access"
 
     # These are used for validation
-    Message = "m.room.message"
-    Topic = "m.room.topic"
-    Name = "m.room.name"
+    Message: Final = "m.room.message"
+    Topic: Final = "m.room.topic"
+    Name: Final = "m.room.name"
 
-    ServerACL = "m.room.server_acl"
-    Pinned = "m.room.pinned_events"
+    ServerACL: Final = "m.room.server_acl"
+    Pinned: Final = "m.room.pinned_events"
 
-    Retention = "m.room.retention"
+    Retention: Final = "m.room.retention"
 
-    Dummy = "org.matrix.dummy_event"
+    Dummy: Final = "org.matrix.dummy_event"
 
-    SpaceChild = "m.space.child"
-    SpaceParent = "m.space.parent"
+    SpaceChild: Final = "m.space.child"
+    SpaceParent: Final = "m.space.parent"
 
-    MSC2716_INSERTION = "org.matrix.msc2716.insertion"
-    MSC2716_BATCH = "org.matrix.msc2716.batch"
-    MSC2716_MARKER = "org.matrix.msc2716.marker"
+    MSC2716_INSERTION: Final = "org.matrix.msc2716.insertion"
+    MSC2716_BATCH: Final = "org.matrix.msc2716.batch"
+    MSC2716_MARKER: Final = "org.matrix.msc2716.marker"
 
 
 class ToDeviceEventTypes:
-    RoomKeyRequest = "m.room_key_request"
+    RoomKeyRequest: Final = "m.room_key_request"
 
 
 class DeviceKeyAlgorithms:
     """Spec'd algorithms for the generation of per-device keys"""
 
-    ED25519 = "ed25519"
-    CURVE25519 = "curve25519"
-    SIGNED_CURVE25519 = "signed_curve25519"
+    ED25519: Final = "ed25519"
+    CURVE25519: Final = "curve25519"
+    SIGNED_CURVE25519: Final = "signed_curve25519"
 
 
 class EduTypes:
-    Presence = "m.presence"
+    Presence: Final = "m.presence"
 
 
 class RejectedReason:
-    AUTH_ERROR = "auth_error"
+    AUTH_ERROR: Final = "auth_error"
 
 
 class RoomCreationPreset:
-    PRIVATE_CHAT = "private_chat"
-    PUBLIC_CHAT = "public_chat"
-    TRUSTED_PRIVATE_CHAT = "trusted_private_chat"
+    PRIVATE_CHAT: Final = "private_chat"
+    PUBLIC_CHAT: Final = "public_chat"
+    TRUSTED_PRIVATE_CHAT: Final = "trusted_private_chat"
 
 
 class ThirdPartyEntityKind:
-    USER = "user"
-    LOCATION = "location"
+    USER: Final = "user"
+    LOCATION: Final = "location"
 
 
-ServerNoticeMsgType = "m.server_notice"
-ServerNoticeLimitReached = "m.server_notice.usage_limit_reached"
+ServerNoticeMsgType: Final = "m.server_notice"
+ServerNoticeLimitReached: Final = "m.server_notice.usage_limit_reached"
 
 
 class UserTypes:
@@ -165,91 +167,91 @@ class UserTypes:
     'admin' and 'guest' users should also be UserTypes. Normal users are type None
     """
 
-    SUPPORT = "support"
-    BOT = "bot"
-    ALL_USER_TYPES = (SUPPORT, BOT)
+    SUPPORT: Final = "support"
+    BOT: Final = "bot"
+    ALL_USER_TYPES: Final = (SUPPORT, BOT)
 
 
 class RelationTypes:
     """The types of relations known to this server."""
 
-    ANNOTATION = "m.annotation"
-    REPLACE = "m.replace"
-    REFERENCE = "m.reference"
-    THREAD = "io.element.thread"
+    ANNOTATION: Final = "m.annotation"
+    REPLACE: Final = "m.replace"
+    REFERENCE: Final = "m.reference"
+    THREAD: Final = "io.element.thread"
 
 
 class LimitBlockingTypes:
     """Reasons that a server may be blocked"""
 
-    MONTHLY_ACTIVE_USER = "monthly_active_user"
-    HS_DISABLED = "hs_disabled"
+    MONTHLY_ACTIVE_USER: Final = "monthly_active_user"
+    HS_DISABLED: Final = "hs_disabled"
 
 
 class EventContentFields:
     """Fields found in events' content, regardless of type."""
 
     # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326
-    LABELS = "org.matrix.labels"
+    LABELS: Final = "org.matrix.labels"
 
     # Timestamp to delete the event after
     # cf https://github.com/matrix-org/matrix-doc/pull/2228
-    SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after"
+    SELF_DESTRUCT_AFTER: Final = "org.matrix.self_destruct_after"
 
     # cf https://github.com/matrix-org/matrix-doc/pull/1772
-    ROOM_TYPE = "type"
+    ROOM_TYPE: Final = "type"
 
     # Whether a room can federate.
-    FEDERATE = "m.federate"
+    FEDERATE: Final = "m.federate"
 
     # The creator of the room, as used in `m.room.create` events.
-    ROOM_CREATOR = "creator"
+    ROOM_CREATOR: Final = "creator"
 
     # Used in m.room.guest_access events.
-    GUEST_ACCESS = "guest_access"
+    GUEST_ACCESS: Final = "guest_access"
 
     # Used on normal messages to indicate they were historically imported after the fact
-    MSC2716_HISTORICAL = "org.matrix.msc2716.historical"
+    MSC2716_HISTORICAL: Final = "org.matrix.msc2716.historical"
     # For "insertion" events to indicate what the next batch ID should be in
     # order to connect to it
-    MSC2716_NEXT_BATCH_ID = "org.matrix.msc2716.next_batch_id"
+    MSC2716_NEXT_BATCH_ID: Final = "org.matrix.msc2716.next_batch_id"
     # Used on "batch" events to indicate which insertion event it connects to
-    MSC2716_BATCH_ID = "org.matrix.msc2716.batch_id"
+    MSC2716_BATCH_ID: Final = "org.matrix.msc2716.batch_id"
     # For "marker" events
-    MSC2716_MARKER_INSERTION = "org.matrix.msc2716.marker.insertion"
+    MSC2716_MARKER_INSERTION: Final = "org.matrix.msc2716.marker.insertion"
 
     # The authorising user for joining a restricted room.
-    AUTHORISING_USER = "join_authorised_via_users_server"
+    AUTHORISING_USER: Final = "join_authorised_via_users_server"
 
 
 class RoomTypes:
     """Understood values of the room_type field of m.room.create events."""
 
-    SPACE = "m.space"
+    SPACE: Final = "m.space"
 
 
 class RoomEncryptionAlgorithms:
-    MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2"
-    DEFAULT = MEGOLM_V1_AES_SHA2
+    MEGOLM_V1_AES_SHA2: Final = "m.megolm.v1.aes-sha2"
+    DEFAULT: Final = MEGOLM_V1_AES_SHA2
 
 
 class AccountDataTypes:
-    DIRECT = "m.direct"
-    IGNORED_USER_LIST = "m.ignored_user_list"
+    DIRECT: Final = "m.direct"
+    IGNORED_USER_LIST: Final = "m.ignored_user_list"
 
 
 class HistoryVisibility:
-    INVITED = "invited"
-    JOINED = "joined"
-    SHARED = "shared"
-    WORLD_READABLE = "world_readable"
+    INVITED: Final = "invited"
+    JOINED: Final = "joined"
+    SHARED: Final = "shared"
+    WORLD_READABLE: Final = "world_readable"
 
 
 class GuestAccess:
-    CAN_JOIN = "can_join"
+    CAN_JOIN: Final = "can_join"
     # anything that is not "can_join" is considered "forbidden", but for completeness:
-    FORBIDDEN = "forbidden"
+    FORBIDDEN: Final = "forbidden"
 
 
 class ReadReceiptEventFields:
-    MSC2285_HIDDEN = "org.matrix.msc2285.hidden"
+    MSC2285_HIDDEN: Final = "org.matrix.msc2285.hidden"
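
The constants.py change above is purely a typing one: annotating each constant with Final lets a static checker such as mypy treat it as a literal value that must not be reassigned; nothing changes at runtime. A minimal sketch of the effect (the reassignments below are illustrative, not from this diff):

from typing_extensions import Final


class EventTypes:
    Member: Final = "m.room.member"


# Runs fine at runtime, since Final is purely a static-analysis marker,
# but mypy reports: Cannot assign to final attribute "Member".
EventTypes.Member = "m.room.membership"

MAX_PDU_SIZE: Final = 65536
# mypy reports: Cannot assign to final name "MAX_PDU_SIZE".
MAX_PDU_SIZE = 1 << 17
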
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 502cc8e8d1..b4bed5bf40 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -113,6 +113,7 @@ from synapse.storage.databases.main.monthly_active_users import (
 )
 from synapse.storage.databases.main.presence import PresenceStore
 from synapse.storage.databases.main.room import RoomWorkerStore
+from synapse.storage.databases.main.room_batch import RoomBatchStore
 from synapse.storage.databases.main.search import SearchStore
 from synapse.storage.databases.main.session import SessionStore
 from synapse.storage.databases.main.stats import StatsStore
@@ -240,6 +241,7 @@ class GenericWorkerSlavedStore(
     SlavedEventStore,
     SlavedKeyStore,
     RoomWorkerStore,
+    RoomBatchStore,
     DirectoryStore,
     SlavedApplicationServiceStore,
     SlavedRegistrationStore,
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 7e09530ad2..52541faab2 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -194,6 +194,7 @@ class SynapseHomeServer(HomeServer):
                 {
                     "/_matrix/client/api/v1": client_resource,
                     "/_matrix/client/r0": client_resource,
+                    "/_matrix/client/v1": client_resource,
                     "/_matrix/client/v3": client_resource,
                     "/_matrix/client/unstable": client_resource,
                     "/_matrix/client/v2_alpha": client_resource,
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 61e569d412..1ddad7cb70 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Optional
 
 from synapse.api.constants import RoomCreationPreset
 from synapse.config._base import Config, ConfigError
@@ -113,32 +114,24 @@ class RegistrationConfig(Config):
         self.session_lifetime = session_lifetime
 
         # The `refreshable_access_token_lifetime` applies for tokens that can be renewed
-        # using a refresh token, as per MSC2918. If it is `None`, the refresh
-        # token mechanism is disabled.
-        #
-        # Since it is incompatible with the `session_lifetime` mechanism, it is set to
-        # `None` by default if a `session_lifetime` is set.
+        # using a refresh token, as per MSC2918.
+        # If it is `None`, the refresh token mechanism is disabled.
         refreshable_access_token_lifetime = config.get(
             "refreshable_access_token_lifetime",
-            "5m" if session_lifetime is None else None,
+            "5m",
         )
         if refreshable_access_token_lifetime is not None:
             refreshable_access_token_lifetime = self.parse_duration(
                 refreshable_access_token_lifetime
             )
-        self.refreshable_access_token_lifetime = refreshable_access_token_lifetime
-
-        if (
-            session_lifetime is not None
-            and refreshable_access_token_lifetime is not None
-        ):
-            raise ConfigError(
-                "The refresh token mechanism is incompatible with the "
-                "`session_lifetime` option. Consider disabling the "
-                "`session_lifetime` option or disabling the refresh token "
-                "mechanism by removing the `refreshable_access_token_lifetime` "
-                "option."
-            )
+        self.refreshable_access_token_lifetime: Optional[
+            int
+        ] = refreshable_access_token_lifetime
+
+        refresh_token_lifetime = config.get("refresh_token_lifetime")
+        if refresh_token_lifetime is not None:
+            refresh_token_lifetime = self.parse_duration(refresh_token_lifetime)
+        self.refresh_token_lifetime: Optional[int] = refresh_token_lifetime
 
         # The fallback template used for authenticating using a registration token
         self.registration_token_template = self.read_template("registration_token.html")
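
Both lifetime options above are passed through Config.parse_duration, which turns a duration string into milliseconds. The real helper is not part of this diff; the stand-in below is an assumption about the relevant behaviour and only handles a few unit suffixes:

# Minimal stand-in for Config.parse_duration (the real helper supports
# more inputs); converts strings like "5m" to milliseconds.
_UNITS_MS = {"s": 1_000, "m": 60_000, "h": 3_600_000, "d": 86_400_000}


def parse_duration(value: str) -> int:
    if value and value[-1] in _UNITS_MS:
        return int(value[:-1]) * _UNITS_MS[value[-1]]
    return int(value)  # no suffix: assume the value is already in ms


assert parse_duration("5m") == 300_000  # the new default access token lifetime
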
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 4cda439ad9..993b04099e 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -667,21 +667,25 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
             perspective_name,
         )
 
+        request: JsonDict = {}
+        for queue_value in keys_to_fetch:
+            # there may be multiple requests for each server, so we have to merge
+            # them intelligently.
+            request_for_server = {
+                key_id: {
+                    "minimum_valid_until_ts": queue_value.minimum_valid_until_ts,
+                }
+                for key_id in queue_value.key_ids
+            }
+            request.setdefault(queue_value.server_name, {}).update(request_for_server)
+
+        logger.debug("Request to notary server %s: %s", perspective_name, request)
+
         try:
             query_response = await self.client.post_json(
                 destination=perspective_name,
                 path="/_matrix/key/v2/query",
-                data={
-                    "server_keys": {
-                        queue_value.server_name: {
-                            key_id: {
-                                "minimum_valid_until_ts": queue_value.minimum_valid_until_ts,
-                            }
-                            for key_id in queue_value.key_ids
-                        }
-                        for queue_value in keys_to_fetch
-                    }
-                },
+                data={"server_keys": request},
             )
         except (NotRetryingDestination, RequestSendFailed) as e:
             # these both have str() representations which we can't really improve upon
@@ -689,6 +693,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
         except HttpResponseException as e:
             raise KeyLookupError("Remote server returned an error: %s" % (e,))
 
+        logger.debug(
+            "Response from notary server %s: %s", perspective_name, query_response
+        )
+
         keys: Dict[str, Dict[str, FetchKeyResult]] = {}
         added_keys: List[Tuple[str, str, FetchKeyResult]] = []
 
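
Besides the new logging, the keyring rewrite above fixes a subtle bug: the old nested dict comprehension was keyed on queue_value.server_name, so two queued requests for the same server silently overwrote one another, whereas setdefault(...).update(...) merges them. A self-contained illustration with hypothetical key IDs:

queue = [
    ("matrix.org", {"ed25519:a": {"minimum_valid_until_ts": 1000}}),
    ("matrix.org", {"ed25519:b": {"minimum_valid_until_ts": 2000}}),
]

# Old shape: a dict comprehension keyed on the server name keeps only
# the last entry for a duplicated server.
comprehension = {server: keys for server, keys in queue}
assert list(comprehension["matrix.org"]) == ["ed25519:b"]  # first entry lost

# New shape: merge the per-server requests instead.
merged: dict = {}
for server, keys in queue:
    merged.setdefault(server, {}).update(keys)
assert sorted(merged["matrix.org"]) == ["ed25519:a", "ed25519:b"]
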
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index d7527008c4..f251402ed8 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -322,6 +322,11 @@ class _AsyncEventContextImpl(EventContext):
         attributes by loading from the database.
         """
         if self.state_group is None:
+            # No state group means the event is an outlier. Usually the state_ids dicts are also
+            # pre-set to empty dicts, but they get reset when the context is serialized, so set
+            # them to empty dicts again here.
+            self._current_state_ids = {}
+            self._prev_state_ids = {}
             return
 
         current_state_ids = await self._storage.state.get_state_ids_for_group(
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 3b85b135e0..bc3f96c1fc 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -1395,11 +1395,28 @@ class FederationClient(FederationBase):
         async def send_request(
             destination: str,
         ) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]:
-            res = await self.transport_layer.get_room_hierarchy(
-                destination=destination,
-                room_id=room_id,
-                suggested_only=suggested_only,
-            )
+            try:
+                res = await self.transport_layer.get_room_hierarchy(
+                    destination=destination,
+                    room_id=room_id,
+                    suggested_only=suggested_only,
+                )
+            except HttpResponseException as e:
+                # If an error is received that is due to an unrecognised endpoint,
+                # fallback to the unstable endpoint. Otherwise consider it a
+                # legitimate error and raise.
+                if not self._is_unknown_endpoint(e):
+                    raise
+
+                logger.debug(
+                    "Couldn't fetch room hierarchy with the v1 API, falling back to the unstable API"
+                )
+
+                res = await self.transport_layer.get_room_hierarchy_unstable(
+                    destination=destination,
+                    room_id=room_id,
+                    suggested_only=suggested_only,
+                )
 
             room = res.get("room")
             if not isinstance(room, dict):
@@ -1449,6 +1466,10 @@ class FederationClient(FederationBase):
             if e.code != 502:
                 raise
 
+            logger.debug(
+                "Couldn't fetch room hierarchy, falling back to the spaces API"
+            )
+
             # Fallback to the old federation API and translate the results if
             # no servers implement the new API.
             #
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 9a8758e9a6..8fbc75aa65 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -613,8 +613,11 @@ class FederationServer(FederationBase):
         state = await self.store.get_events(state_ids)
 
         time_now = self._clock.time_msec()
+        event_json = event.get_pdu_json()
         return {
-            "org.matrix.msc3083.v2.event": event.get_pdu_json(),
+            # TODO Remove the unstable prefix when servers have updated.
+            "org.matrix.msc3083.v2.event": event_json,
+            "event": event_json,
             "state": [p.get_pdu_json(time_now) for p in state.values()],
             "auth_chain": [p.get_pdu_json(time_now) for p in auth_chain],
         }
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 10b5aa5af8..fe29bcfd4b 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -1192,10 +1192,24 @@ class TransportLayerClient:
         )
 
     async def get_room_hierarchy(
-        self,
-        destination: str,
-        room_id: str,
-        suggested_only: bool,
+        self, destination: str, room_id: str, suggested_only: bool
+    ) -> JsonDict:
+        """
+        Args:
+            destination: The remote server
+            room_id: The room ID to ask about.
+            suggested_only: if True, only suggested rooms will be returned
+        """
+        path = _create_v1_path("/hierarchy/%s", room_id)
+
+        return await self.client.get_json(
+            destination=destination,
+            path=path,
+            args={"suggested_only": "true" if suggested_only else "false"},
+        )
+
+    async def get_room_hierarchy_unstable(
+        self, destination: str, room_id: str, suggested_only: bool
     ) -> JsonDict:
         """
         Args:
@@ -1317,15 +1331,26 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
             prefix + "auth_chain.item",
             use_float=True,
         )
-        self._coro_event = ijson.kvitems_coro(
+        # TODO Remove the unstable prefix when servers have updated.
+        #
+        # By re-using the same event dictionary this will cause the parsing of
+        # org.matrix.msc3083.v2.event and event to stomp over each other.
+        # Generally this should be fine.
+        self._coro_unstable_event = ijson.kvitems_coro(
             _event_parser(self._response.event_dict),
             prefix + "org.matrix.msc3083.v2.event",
             use_float=True,
         )
+        self._coro_event = ijson.kvitems_coro(
+            _event_parser(self._response.event_dict),
+            prefix + "event",
+            use_float=True,
+        )
 
     def write(self, data: bytes) -> int:
         self._coro_state.send(data)
         self._coro_auth.send(data)
+        self._coro_unstable_event.send(data)
         self._coro_event.send(data)
 
         return len(data)
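
SendJoinParser now feeds the response bytes to two kvitems coroutines that share one event dictionary, so whichever of "event" and "org.matrix.msc3083.v2.event" is parsed last wins; since federation_server.py above now sends the same event under both keys, the stomping is harmless. A simplified, non-streaming illustration of the same idea (assuming the ijson package, which this parser is built on):

import io

import ijson

# A toy send_join response carrying the event under both keys, as the
# federation_server.py hunk above now emits.
body = (
    b'{"event": {"type": "m.room.member"},'
    b' "org.matrix.msc3083.v2.event": {"type": "m.room.member"},'
    b' "state": [], "auth_chain": []}'
)

event_dict: dict = {}
for prefix in ("event", "org.matrix.msc3083.v2.event"):
    # Both prefixes write into the same dict, mirroring SendJoinParser.
    for key, value in ijson.kvitems(io.BytesIO(body), prefix):
        event_dict[key] = value

assert event_dict == {"type": "m.room.member"}
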
diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
index 2fdf6cc99e..66e915228c 100644
--- a/synapse/federation/transport/server/federation.py
+++ b/synapse/federation/transport/server/federation.py
@@ -611,7 +611,6 @@ class FederationSpaceSummaryServlet(BaseFederationServlet):
 
 
 class FederationRoomHierarchyServlet(BaseFederationServlet):
-    PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946"
     PATH = "/hierarchy/(?P<room_id>[^/]*)"
 
     def __init__(
@@ -637,6 +636,10 @@ class FederationRoomHierarchyServlet(BaseFederationServlet):
         )
 
 
+class FederationRoomHierarchyUnstableServlet(FederationRoomHierarchyServlet):
+    PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946"
+
+
 class RoomComplexityServlet(BaseFederationServlet):
     """
     Indicates to other servers how complex (and therefore likely
@@ -701,6 +704,7 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
     RoomComplexityServlet,
     FederationSpaceSummaryServlet,
     FederationRoomHierarchyServlet,
+    FederationRoomHierarchyUnstableServlet,
     FederationV1SendKnockServlet,
     FederationMakeKnockServlet,
 )
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 4b66a9862f..4d9c4e5834 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -18,6 +18,7 @@ import time
 import unicodedata
 import urllib.parse
 from binascii import crc32
+from http import HTTPStatus
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -756,53 +757,109 @@ class AuthHandler:
     async def refresh_token(
         self,
         refresh_token: str,
-        valid_until_ms: Optional[int],
-    ) -> Tuple[str, str]:
+        access_token_valid_until_ms: Optional[int],
+        refresh_token_valid_until_ms: Optional[int],
+    ) -> Tuple[str, str, Optional[int]]:
         """
         Consumes a refresh token and generate both a new access token and a new refresh token from it.
 
         The consumed refresh token is considered invalid after the first use of the new access token or the new refresh token.
 
+        The lifetime of both the access token and refresh token will be capped so that they
+        do not exceed the session's ultimate expiry time, if applicable.
+
         Args:
             refresh_token: The token to consume.
-            valid_until_ms: The expiration timestamp of the new access token.
-
+            access_token_valid_until_ms: The expiration timestamp of the new access token.
+                None if the access token does not expire.
+            refresh_token_valid_until_ms: The expiration timestamp of the new refresh token.
+                None if the refresh token does not expire.
         Returns:
-            A tuple containing the new access token and refresh token
+            A tuple containing:
+              - the new access token
+              - the new refresh token
+              - the actual expiry time of the access token, which may be earlier than
+                `access_token_valid_until_ms`.
         """
 
         # Verify the token signature first before looking up the token
         if not self._verify_refresh_token(refresh_token):
-            raise SynapseError(401, "invalid refresh token", Codes.UNKNOWN_TOKEN)
+            raise SynapseError(
+                HTTPStatus.UNAUTHORIZED, "invalid refresh token", Codes.UNKNOWN_TOKEN
+            )
 
         existing_token = await self.store.lookup_refresh_token(refresh_token)
         if existing_token is None:
-            raise SynapseError(401, "refresh token does not exist", Codes.UNKNOWN_TOKEN)
+            raise SynapseError(
+                HTTPStatus.UNAUTHORIZED,
+                "refresh token does not exist",
+                Codes.UNKNOWN_TOKEN,
+            )
 
         if (
             existing_token.has_next_access_token_been_used
             or existing_token.has_next_refresh_token_been_refreshed
         ):
             raise SynapseError(
-                403, "refresh token isn't valid anymore", Codes.FORBIDDEN
+                HTTPStatus.FORBIDDEN,
+                "refresh token isn't valid anymore",
+                Codes.FORBIDDEN,
+            )
+
+        now_ms = self._clock.time_msec()
+
+        if existing_token.expiry_ts is not None and existing_token.expiry_ts < now_ms:
+
+            raise SynapseError(
+                HTTPStatus.FORBIDDEN,
+                "The supplied refresh token has expired",
+                Codes.FORBIDDEN,
             )
 
+        if existing_token.ultimate_session_expiry_ts is not None:
+            # This session has a bounded lifetime, even across refreshes.
+
+            if access_token_valid_until_ms is not None:
+                access_token_valid_until_ms = min(
+                    access_token_valid_until_ms,
+                    existing_token.ultimate_session_expiry_ts,
+                )
+            else:
+                access_token_valid_until_ms = existing_token.ultimate_session_expiry_ts
+
+            if refresh_token_valid_until_ms is not None:
+                refresh_token_valid_until_ms = min(
+                    refresh_token_valid_until_ms,
+                    existing_token.ultimate_session_expiry_ts,
+                )
+            else:
+                refresh_token_valid_until_ms = existing_token.ultimate_session_expiry_ts
+            if existing_token.ultimate_session_expiry_ts < now_ms:
+                raise SynapseError(
+                    HTTPStatus.FORBIDDEN,
+                    "The session has expired and can no longer be refreshed",
+                    Codes.FORBIDDEN,
+                )
+
         (
             new_refresh_token,
             new_refresh_token_id,
         ) = await self.create_refresh_token_for_user_id(
-            user_id=existing_token.user_id, device_id=existing_token.device_id
+            user_id=existing_token.user_id,
+            device_id=existing_token.device_id,
+            expiry_ts=refresh_token_valid_until_ms,
+            ultimate_session_expiry_ts=existing_token.ultimate_session_expiry_ts,
         )
         access_token = await self.create_access_token_for_user_id(
             user_id=existing_token.user_id,
             device_id=existing_token.device_id,
-            valid_until_ms=valid_until_ms,
+            valid_until_ms=access_token_valid_until_ms,
             refresh_token_id=new_refresh_token_id,
         )
         await self.store.replace_refresh_token(
             existing_token.token_id, new_refresh_token_id
         )
-        return access_token, new_refresh_token
+        return access_token, new_refresh_token, access_token_valid_until_ms
 
     def _verify_refresh_token(self, token: str) -> bool:
         """
@@ -836,6 +893,8 @@ class AuthHandler:
         self,
         user_id: str,
         device_id: str,
+        expiry_ts: Optional[int],
+        ultimate_session_expiry_ts: Optional[int],
     ) -> Tuple[str, int]:
         """
         Creates a new refresh token for the user with the given user ID.
@@ -843,6 +902,13 @@ class AuthHandler:
         Args:
             user_id: canonical user ID
             device_id: the device ID to associate with the token.
+            expiry_ts (milliseconds since the epoch): Time after which the
+                refresh token cannot be used.
+                If None, the refresh token never expires until it has been used.
+            ultimate_session_expiry_ts (milliseconds since the epoch):
+                Time at which the session will end and cannot be extended any
+                further.
+                If None, the session can be refreshed indefinitely.
 
         Returns:
             The newly created refresh token and its ID in the database
@@ -852,6 +918,8 @@ class AuthHandler:
             user_id=user_id,
             token=refresh_token,
             device_id=device_id,
+            expiry_ts=expiry_ts,
+            ultimate_session_expiry_ts=ultimate_session_expiry_ts,
         )
         return refresh_token, refresh_token_id
 
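
The clamping logic in refresh_token above reduces to one rule: a token expiry never exceeds the session's ultimate expiry, where None means "no expiry" on the token side and "unbounded" on the session side. Distilled into a standalone helper (the function name is ours, not Synapse's):

from typing import Optional


def cap_expiry(
    requested_ms: Optional[int], session_end_ms: Optional[int]
) -> Optional[int]:
    # None on the token side means "never expires"; None on the session
    # side means the session itself is unbounded.
    if session_end_ms is None:
        return requested_ms
    if requested_ms is None:
        return session_end_ms
    return min(requested_ms, session_end_ms)


assert cap_expiry(5_000, 3_000) == 3_000  # session ends first: clamp
assert cap_expiry(None, 3_000) == 3_000   # even a "non-expiring" token is bounded
assert cap_expiry(5_000, None) == 5_000   # unbounded session: keep the request
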
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 448a36108e..24ca11b924 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -119,6 +119,7 @@ class RegistrationHandler:
         self.refreshable_access_token_lifetime = (
             hs.config.registration.refreshable_access_token_lifetime
         )
+        self.refresh_token_lifetime = hs.config.registration.refresh_token_lifetime
 
         init_counters_for_auth_provider("")
 
@@ -793,13 +794,13 @@ class RegistrationHandler:
         class and RegisterDeviceReplicationServlet.
         """
         assert not self.hs.config.worker.worker_app
-        valid_until_ms = None
+        access_token_expiry = None
         if self.session_lifetime is not None:
             if is_guest:
                 raise Exception(
                     "session_lifetime is not currently implemented for guest access"
                 )
-            valid_until_ms = self.clock.time_msec() + self.session_lifetime
+            access_token_expiry = self.clock.time_msec() + self.session_lifetime
 
         refresh_token = None
         refresh_token_id = None
@@ -808,25 +809,57 @@ class RegistrationHandler:
             user_id, device_id, initial_display_name
         )
         if is_guest:
-            assert valid_until_ms is None
+            assert access_token_expiry is None
             access_token = self.macaroon_gen.generate_guest_access_token(user_id)
         else:
             if should_issue_refresh_token:
+                # A refreshable access token lifetime must be configured
+                # since we're told to issue a refresh token (the caller checks
+                # that this value is set before setting this flag).
+                assert self.refreshable_access_token_lifetime is not None
+
+                now_ms = self.clock.time_msec()
+
+                # Set the expiry time of the refreshable access token
+                access_token_expiry = now_ms + self.refreshable_access_token_lifetime
+
+                # Set the refresh token expiry time (if configured)
+                refresh_token_expiry = None
+                if self.refresh_token_lifetime is not None:
+                    refresh_token_expiry = now_ms + self.refresh_token_lifetime
+
+                # Set an ultimate session expiry time (if configured)
+                ultimate_session_expiry_ts = None
+                if self.session_lifetime is not None:
+                    ultimate_session_expiry_ts = now_ms + self.session_lifetime
+
+                    # Also ensure that the issued tokens don't outlive the
+                    # session.
+                    # (It would be weird to configure a homeserver with a shorter
+                    # session lifetime than token lifetime, but may as well handle
+                    # it.)
+                    access_token_expiry = min(
+                        access_token_expiry, ultimate_session_expiry_ts
+                    )
+                    if refresh_token_expiry is not None:
+                        refresh_token_expiry = min(
+                            refresh_token_expiry, ultimate_session_expiry_ts
+                        )
+
                 (
                     refresh_token,
                     refresh_token_id,
                 ) = await self._auth_handler.create_refresh_token_for_user_id(
                     user_id,
                     device_id=registered_device_id,
-                )
-                valid_until_ms = (
-                    self.clock.time_msec() + self.refreshable_access_token_lifetime
+                    expiry_ts=refresh_token_expiry,
+                    ultimate_session_expiry_ts=ultimate_session_expiry_ts,
                 )
 
             access_token = await self._auth_handler.create_access_token_for_user_id(
                 user_id,
                 device_id=registered_device_id,
-                valid_until_ms=valid_until_ms,
+                valid_until_ms=access_token_expiry,
                 is_appservice_ghost=is_appservice_ghost,
                 refresh_token_id=refresh_token_id,
             )
@@ -834,7 +867,7 @@ class RegistrationHandler:
         return {
             "device_id": registered_device_id,
             "access_token": access_token,
-            "valid_until_ms": valid_until_ms,
+            "valid_until_ms": access_token_expiry,
             "refresh_token": refresh_token,
         }
 
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index 8181cc0b52..b2cfe537df 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -36,8 +36,9 @@ from synapse.api.errors import (
     SynapseError,
     UnsupportedRoomVersionError,
 )
+from synapse.api.ratelimiting import Ratelimiter
 from synapse.events import EventBase
-from synapse.types import JsonDict
+from synapse.types import JsonDict, Requester
 from synapse.util.caches.response_cache import ResponseCache
 
 if TYPE_CHECKING:
@@ -93,6 +94,9 @@ class RoomSummaryHandler:
         self._event_serializer = hs.get_event_client_serializer()
         self._server_name = hs.hostname
         self._federation_client = hs.get_federation_client()
+        self._ratelimiter = Ratelimiter(
+            store=self._store, clock=hs.get_clock(), rate_hz=5, burst_count=10
+        )
 
         # If a user tries to fetch the same page multiple times in quick succession,
         # only process the first attempt and return its result to subsequent requests.
@@ -249,7 +253,7 @@ class RoomSummaryHandler:
 
     async def get_room_hierarchy(
         self,
-        requester: str,
+        requester: Requester,
         requested_room_id: str,
         suggested_only: bool = False,
         max_depth: Optional[int] = None,
@@ -276,6 +280,8 @@ class RoomSummaryHandler:
         Returns:
             The JSON hierarchy dictionary.
         """
+        await self._ratelimiter.ratelimit(requester)
+
         # If a user tries to fetch the same page multiple times in quick succession,
         # only process the first attempt and return its result to subsequent requests.
         #
@@ -283,7 +289,7 @@ class RoomSummaryHandler:
         # to process multiple requests for the same page will result in errors.
         return await self._pagination_response_cache.wrap(
             (
-                requester,
+                requester.user.to_string(),
                 requested_room_id,
                 suggested_only,
                 max_depth,
@@ -291,7 +297,7 @@ class RoomSummaryHandler:
                 from_token,
             ),
             self._get_room_hierarchy,
-            requester,
+            requester.user.to_string(),
             requested_room_id,
             suggested_only,
             max_depth,
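
The hierarchy endpoint is now rate limited at 5 requests per second with a burst allowance of 10. Synapse's Ratelimiter is not shown in this diff; below is a minimal token bucket with the same two knobs, as a sketch rather than the real implementation:

import time


class TokenBucket:
    def __init__(self, rate_hz: float, burst_count: int):
        self.rate_hz = rate_hz
        self.burst_count = burst_count
        self.tokens = float(burst_count)
        self.last = time.monotonic()

    def allow(self) -> bool:
        now = time.monotonic()
        # Refill at rate_hz tokens per second, capped at the burst size.
        self.tokens = min(
            self.burst_count, self.tokens + (now - self.last) * self.rate_hz
        )
        self.last = now
        if self.tokens >= 1.0:
            self.tokens -= 1.0
            return True
        return False


bucket = TokenBucket(rate_hz=5, burst_count=10)
assert all(bucket.allow() for _ in range(10))  # the burst rides through
assert not bucket.allow()                      # the 11th request is throttled
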
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 96d7a8f2a9..a8154168be 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -24,6 +24,7 @@ from typing import (
     List,
     Optional,
     Tuple,
+    TypeVar,
     Union,
 )
 
@@ -81,10 +82,19 @@ from synapse.http.server import (
 )
 from synapse.http.servlet import parse_json_object_from_request
 from synapse.http.site import SynapseRequest
-from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.logging.context import (
+    defer_to_thread,
+    make_deferred_yieldable,
+    run_in_background,
+)
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.rest.client.login import LoginResponse
 from synapse.storage import DataStore
+from synapse.storage.background_updates import (
+    DEFAULT_BATCH_SIZE_CALLBACK,
+    MIN_BATCH_SIZE_CALLBACK,
+    ON_UPDATE_CALLBACK,
+)
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.roommember import ProfileInfo
 from synapse.storage.state import StateFilter
@@ -104,6 +114,9 @@ if TYPE_CHECKING:
     from synapse.app.generic_worker import GenericWorkerSlavedStore
     from synapse.server import HomeServer
 
+
+T = TypeVar("T")
+
 """
 This package defines the 'stable' API which can be used by extension modules which
 are loaded into Synapse.
@@ -307,7 +320,25 @@ class ModuleApi:
             auth_checkers=auth_checkers,
         )
 
-    def register_web_resource(self, path: str, resource: Resource):
+    def register_background_update_controller_callbacks(
+        self,
+        on_update: ON_UPDATE_CALLBACK,
+        default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
+        min_batch_size: Optional[MIN_BATCH_SIZE_CALLBACK] = None,
+    ) -> None:
+        """Registers background update controller callbacks.
+
+        Added in Synapse v1.49.0.
+        """
+
+        for db in self._hs.get_datastores().databases:
+            db.updates.register_update_controller_callbacks(
+                on_update=on_update,
+                default_batch_size=default_batch_size,
+                min_batch_size=min_batch_size,
+            )
+
+    def register_web_resource(self, path: str, resource: Resource) -> None:
         """Registers a web resource to be served at the given path.
 
         This function should be called during initialisation of the module.
@@ -432,7 +463,7 @@ class ModuleApi:
             username: provided user id
 
         Returns:
-            str: qualified @user:id
+            qualified @user:id
         """
         if username.startswith("@"):
             return username
@@ -468,7 +499,7 @@ class ModuleApi:
         """
         return await self._store.user_get_threepids(user_id)
 
-    def check_user_exists(self, user_id: str):
+    def check_user_exists(self, user_id: str) -> "defer.Deferred[Optional[str]]":
         """Check if user exists.
 
         Added in Synapse v0.25.0.
@@ -477,13 +508,18 @@ class ModuleApi:
             user_id: Complete @user:id
 
         Returns:
-            Deferred[str|None]: Canonical (case-corrected) user_id, or None
+            Canonical (case-corrected) user_id, or None
                if the user is not registered.
         """
         return defer.ensureDeferred(self._auth_handler.check_user_exists(user_id))
 
     @defer.inlineCallbacks
-    def register(self, localpart, displayname=None, emails: Optional[List[str]] = None):
+    def register(
+        self,
+        localpart: str,
+        displayname: Optional[str] = None,
+        emails: Optional[List[str]] = None,
+    ) -> Generator["defer.Deferred[Any]", Any, Tuple[str, str]]:
         """Registers a new user with given localpart and optional displayname, emails.
 
         Also returns an access token for the new user.
@@ -495,12 +531,12 @@ class ModuleApi:
         Added in Synapse v0.25.0.
 
         Args:
-            localpart (str): The localpart of the new user.
-            displayname (str|None): The displayname of the new user.
-            emails (List[str]): Emails to bind to the new user.
+            localpart: The localpart of the new user.
+            displayname: The displayname of the new user.
+            emails: Emails to bind to the new user.
 
         Returns:
-            Deferred[tuple[str, str]]: a 2-tuple of (user_id, access_token)
+            a 2-tuple of (user_id, access_token)
         """
         logger.warning(
             "Using deprecated ModuleApi.register which creates a dummy user device."
@@ -510,23 +546,26 @@ class ModuleApi:
         return user_id, access_token
 
     def register_user(
-        self, localpart, displayname=None, emails: Optional[List[str]] = None
-    ):
+        self,
+        localpart: str,
+        displayname: Optional[str] = None,
+        emails: Optional[List[str]] = None,
+    ) -> "defer.Deferred[str]":
         """Registers a new user with given localpart and optional displayname, emails.
 
         Added in Synapse v1.2.0.
 
         Args:
-            localpart (str): The localpart of the new user.
-            displayname (str|None): The displayname of the new user.
-            emails (List[str]): Emails to bind to the new user.
+            localpart: The localpart of the new user.
+            displayname: The displayname of the new user.
+            emails: Emails to bind to the new user.
 
         Raises:
             SynapseError if there is an error performing the registration. Check the
                 'errcode' property for more information on the reason for failure
 
         Returns:
-            defer.Deferred[str]: user_id
+            user_id
         """
         return defer.ensureDeferred(
             self._hs.get_registration_handler().register_user(
@@ -536,20 +575,25 @@ class ModuleApi:
             )
         )
 
-    def register_device(self, user_id, device_id=None, initial_display_name=None):
+    def register_device(
+        self,
+        user_id: str,
+        device_id: Optional[str] = None,
+        initial_display_name: Optional[str] = None,
+    ) -> "defer.Deferred[Tuple[str, str, Optional[int], Optional[str]]]":
         """Register a device for a user and generate an access token.
 
         Added in Synapse v1.2.0.
 
         Args:
-            user_id (str): full canonical @user:id
-            device_id (str|None): The device ID to check, or None to generate
+            user_id: full canonical @user:id
+            device_id: The device ID to check, or None to generate
                 a new one.
-            initial_display_name (str|None): An optional display name for the
+            initial_display_name: An optional display name for the
                 device.
 
         Returns:
-            defer.Deferred[tuple[str, str]]: Tuple of device ID and access token
+            Tuple of device ID, access token, access token expiration time and refresh token
         """
         return defer.ensureDeferred(
             self._hs.get_registration_handler().register_device(
@@ -603,7 +647,9 @@ class ModuleApi:
         )
 
     @defer.inlineCallbacks
-    def invalidate_access_token(self, access_token):
+    def invalidate_access_token(
+        self, access_token: str
+    ) -> Generator["defer.Deferred[Any]", Any, None]:
         """Invalidate an access token for a user
 
         Added in Synapse v0.25.0.
@@ -635,14 +681,20 @@ class ModuleApi:
                 self._auth_handler.delete_access_token(access_token)
             )
 
-    def run_db_interaction(self, desc, func, *args, **kwargs):
+    def run_db_interaction(
+        self,
+        desc: str,
+        func: Callable[..., T],
+        *args: Any,
+        **kwargs: Any,
+    ) -> "defer.Deferred[T]":
         """Run a function with a database connection
 
         Added in Synapse v0.25.0.
 
         Args:
-            desc (str): description for the transaction, for metrics etc
-            func (func): function to be run. Passed a database cursor object
+            desc: description for the transaction, for metrics etc
+            func: function to be run. Passed a database cursor object
                 as well as *args and **kwargs
             *args: positional args to be passed to func
             **kwargs: named args to be passed to func
@@ -656,7 +708,7 @@ class ModuleApi:
 
     def complete_sso_login(
         self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
-    ):
+    ) -> None:
         """Complete a SSO login by redirecting the user to a page to confirm whether they
         want their access token sent to `client_redirect_url`, or redirect them to that
         URL with a token directly if the URL matches with one of the whitelisted clients.
@@ -686,7 +738,7 @@ class ModuleApi:
         client_redirect_url: str,
         new_user: bool = False,
         auth_provider_id: str = "<unknown>",
-    ):
+    ) -> None:
         """Complete a SSO login by redirecting the user to a page to confirm whether they
         want their access token sent to `client_redirect_url`, or redirect them to that
         URL with a token directly if the URL matches with one of the whitelisted clients.
@@ -925,11 +977,11 @@ class ModuleApi:
         self,
         f: Callable,
         msec: float,
-        *args,
+        *args: object,
         desc: Optional[str] = None,
         run_on_all_instances: bool = False,
-        **kwargs,
-    ):
+        **kwargs: object,
+    ) -> None:
         """Wraps a function as a background process and calls it repeatedly.
 
         NOTE: Will only run on the instance that is configured to run
@@ -970,13 +1022,18 @@ class ModuleApi:
                 f,
             )
 
+    async def sleep(self, seconds: float) -> None:
+        """Sleeps for the given number of seconds."""
+
+        await self._clock.sleep(seconds)
+
     async def send_mail(
         self,
         recipient: str,
         subject: str,
         html: str,
         text: str,
-    ):
+    ) -> None:
         """Send an email on behalf of the homeserver.
 
         Added in Synapse v1.39.0.
@@ -1124,6 +1181,26 @@ class ModuleApi:
 
         return {key: state_events[event_id] for key, event_id in state_ids.items()}
 
+    async def defer_to_thread(
+        self,
+        f: Callable[..., T],
+        *args: Any,
+        **kwargs: Any,
+    ) -> T:
+        """Runs the given function in a separate thread from Synapse's thread pool.
+
+        Added in Synapse v1.49.0.
+
+        Args:
+            f: The function to run.
+            args: The function's arguments.
+            kwargs: The function's keyword arguments.
+
+        Returns:
+            The return value of the function once run in a thread.
+        """
+        return await defer_to_thread(self._hs.get_reactor(), f, *args, **kwargs)
+
 
 class PublicRoomListManager:
     """Contains methods for adding to, removing from and querying whether a room
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index cf5abdfbda..4f13c0418a 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -21,6 +21,8 @@ from twisted.internet.interfaces import IDelayedCall
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.push import Pusher, PusherConfig, PusherConfigException, ThrottleParams
 from synapse.push.mailer import Mailer
+from synapse.push.push_types import EmailReason
+from synapse.storage.databases.main.event_push_actions import EmailPushAction
 from synapse.util.threepids import validate_email
 
 if TYPE_CHECKING:
@@ -190,7 +192,7 @@ class EmailPusher(Pusher):
                 # we then consider all previously outstanding notifications
                 # to be delivered.
 
-                reason = {
+                reason: EmailReason = {
                     "room_id": push_action["room_id"],
                     "now": self.clock.time_msec(),
                     "received_at": received_at,
@@ -275,7 +277,7 @@ class EmailPusher(Pusher):
         return may_send_at
 
     async def sent_notif_update_throttle(
-        self, room_id: str, notified_push_action: dict
+        self, room_id: str, notified_push_action: EmailPushAction
     ) -> None:
         # We have sent a notification, so update the throttle accordingly.
         # If the event that triggered the notif happened more than
@@ -315,7 +317,9 @@ class EmailPusher(Pusher):
             self.pusher_id, room_id, self.throttle_params[room_id]
         )
 
-    async def send_notification(self, push_actions: List[dict], reason: dict) -> None:
+    async def send_notification(
+        self, push_actions: List[EmailPushAction], reason: EmailReason
+    ) -> None:
         logger.info("Sending notif email for user %r", self.user_id)
 
         await self.mailer.send_notification_mail(
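
EmailReason and EmailPushAction come from the new push_types.py and the event_push_actions typing, whose hunks are not included here. A hypothetical reconstruction of the TypedDict pattern, restricted to keys actually visible in the emailpusher.py hunk above (the real definitions very likely carry more fields):

from typing_extensions import TypedDict


class EmailReason(TypedDict, total=False):
    # Hypothetical shape; only keys used in the hunk above are listed.
    room_id: str
    now: int
    received_at: int


reason: EmailReason = {"room_id": "!room:example.org", "now": 1_638_000_000_000}
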
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index dbf4ad7f97..3fa603ccb7 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -26,6 +26,7 @@ from synapse.events import EventBase
 from synapse.logging import opentracing
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.push import Pusher, PusherConfig, PusherConfigException
+from synapse.storage.databases.main.event_push_actions import HttpPushAction
 
 from . import push_rule_evaluator, push_tools
 
@@ -273,7 +274,7 @@ class HttpPusher(Pusher):
                     )
                     break
 
-    async def _process_one(self, push_action: dict) -> bool:
+    async def _process_one(self, push_action: HttpPushAction) -> bool:
         if "notify" not in push_action["actions"]:
             return True
 
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index ce299ba3da..ba4f866487 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -14,7 +14,7 @@
 
 import logging
 import urllib.parse
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, TypeVar
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, TypeVar
 
 import bleach
 import jinja2
@@ -28,6 +28,14 @@ from synapse.push.presentable_names import (
     descriptor_from_member_events,
     name_from_member_event,
 )
+from synapse.push.push_types import (
+    EmailReason,
+    MessageVars,
+    NotifVars,
+    RoomVars,
+    TemplateVars,
+)
+from synapse.storage.databases.main.event_push_actions import EmailPushAction
 from synapse.storage.state import StateFilter
 from synapse.types import StateMap, UserID
 from synapse.util.async_helpers import concurrently_execute
@@ -135,7 +143,7 @@ class Mailer:
             % urllib.parse.urlencode(params)
         )
 
-        template_vars = {"link": link}
+        template_vars: TemplateVars = {"link": link}
 
         await self.send_email(
             email_address,
@@ -165,7 +173,7 @@ class Mailer:
             % urllib.parse.urlencode(params)
         )
 
-        template_vars = {"link": link}
+        template_vars: TemplateVars = {"link": link}
 
         await self.send_email(
             email_address,
@@ -196,7 +204,7 @@ class Mailer:
             % urllib.parse.urlencode(params)
         )
 
-        template_vars = {"link": link}
+        template_vars: TemplateVars = {"link": link}
 
         await self.send_email(
             email_address,
@@ -210,8 +218,8 @@ class Mailer:
         app_id: str,
         user_id: str,
         email_address: str,
-        push_actions: Iterable[Dict[str, Any]],
-        reason: Dict[str, Any],
+        push_actions: Iterable[EmailPushAction],
+        reason: EmailReason,
     ) -> None:
         """
         Send email regarding a user's room notifications
@@ -230,7 +238,7 @@ class Mailer:
             [pa["event_id"] for pa in push_actions]
         )
 
-        notifs_by_room: Dict[str, List[Dict[str, Any]]] = {}
+        notifs_by_room: Dict[str, List[EmailPushAction]] = {}
         for pa in push_actions:
             notifs_by_room.setdefault(pa["room_id"], []).append(pa)
 
@@ -258,7 +266,7 @@ class Mailer:
         # actually sort our so-called rooms_in_order list, most recent room first
         rooms_in_order.sort(key=lambda r: -(notifs_by_room[r][-1]["received_ts"] or 0))
 
-        rooms: List[Dict[str, Any]] = []
+        rooms: List[RoomVars] = []
 
         for r in rooms_in_order:
             roomvars = await self._get_room_vars(
@@ -289,7 +297,7 @@ class Mailer:
                 notifs_by_room, state_by_room, notif_events, reason
             )
 
-        template_vars = {
+        template_vars: TemplateVars = {
             "user_display_name": user_display_name,
             "unsubscribe_link": self._make_unsubscribe_link(
                 user_id, app_id, email_address
@@ -302,10 +310,10 @@ class Mailer:
         await self.send_email(email_address, summary_text, template_vars)
 
     async def send_email(
-        self, email_address: str, subject: str, extra_template_vars: Dict[str, Any]
+        self, email_address: str, subject: str, extra_template_vars: TemplateVars
     ) -> None:
         """Send an email with the given information and template text"""
-        template_vars = {
+        template_vars: TemplateVars = {
             "app_name": self.app_name,
             "server_name": self.hs.config.server.server_name,
         }
@@ -327,10 +335,10 @@ class Mailer:
         self,
         room_id: str,
         user_id: str,
-        notifs: Iterable[Dict[str, Any]],
+        notifs: Iterable[EmailPushAction],
         notif_events: Dict[str, EventBase],
         room_state_ids: StateMap[str],
-    ) -> Dict[str, Any]:
+    ) -> RoomVars:
         """
         Generate the variables for notifications on a per-room basis.
 
@@ -356,7 +364,7 @@ class Mailer:
 
         room_name = await calculate_room_name(self.store, room_state_ids, user_id)
 
-        room_vars: Dict[str, Any] = {
+        room_vars: RoomVars = {
             "title": room_name,
             "hash": string_ordinal_total(room_id),  # See sender avatar hash
             "notifs": [],
@@ -417,11 +425,11 @@ class Mailer:
 
     async def _get_notif_vars(
         self,
-        notif: Dict[str, Any],
+        notif: EmailPushAction,
         user_id: str,
         notif_event: EventBase,
         room_state_ids: StateMap[str],
-    ) -> Dict[str, Any]:
+    ) -> NotifVars:
         """
         Generate the variables for a single notification.
 
@@ -442,7 +450,7 @@ class Mailer:
             after_limit=CONTEXT_AFTER,
         )
 
-        ret = {
+        ret: NotifVars = {
             "link": self._make_notif_link(notif),
             "ts": notif["received_ts"],
             "messages": [],
@@ -461,8 +469,8 @@ class Mailer:
         return ret
 
     async def _get_message_vars(
-        self, notif: Dict[str, Any], event: EventBase, room_state_ids: StateMap[str]
-    ) -> Optional[Dict[str, Any]]:
+        self, notif: EmailPushAction, event: EventBase, room_state_ids: StateMap[str]
+    ) -> Optional[MessageVars]:
         """
         Generate the variables for a single event, if possible.
 
@@ -494,7 +502,9 @@ class Mailer:
 
         if sender_state_event:
             sender_name = name_from_member_event(sender_state_event)
-            sender_avatar_url = sender_state_event.content.get("avatar_url")
+            sender_avatar_url: Optional[str] = sender_state_event.content.get(
+                "avatar_url"
+            )
         else:
             # No state could be found, fallback to the MXID.
             sender_name = event.sender
@@ -504,7 +514,7 @@ class Mailer:
         # sender_hash % the number of default images to choose from
         sender_hash = string_ordinal_total(event.sender)
 
-        ret = {
+        ret: MessageVars = {
             "event_type": event.type,
             "is_historical": event.event_id != notif["event_id"],
             "id": event.event_id,
@@ -519,6 +529,8 @@ class Mailer:
             return ret
 
         msgtype = event.content.get("msgtype")
+        if not isinstance(msgtype, str):
+            msgtype = None
 
         ret["msgtype"] = msgtype
 
@@ -533,7 +545,7 @@ class Mailer:
         return ret
 
     def _add_text_message_vars(
-        self, messagevars: Dict[str, Any], event: EventBase
+        self, messagevars: MessageVars, event: EventBase
     ) -> None:
         """
         Potentially add a sanitised message body to the message variables.
@@ -543,8 +555,8 @@ class Mailer:
             event: The event under consideration.
         """
         msgformat = event.content.get("format")
-
-        messagevars["format"] = msgformat
+        if not isinstance(msgformat, str):
+            msgformat = None
 
         formatted_body = event.content.get("formatted_body")
         body = event.content.get("body")
@@ -555,7 +567,7 @@ class Mailer:
             messagevars["body_text_html"] = safe_text(body)
 
     def _add_image_message_vars(
-        self, messagevars: Dict[str, Any], event: EventBase
+        self, messagevars: MessageVars, event: EventBase
     ) -> None:
         """
         Potentially add an image URL to the message variables.
@@ -570,7 +582,7 @@ class Mailer:
     async def _make_summary_text_single_room(
         self,
         room_id: str,
-        notifs: List[Dict[str, Any]],
+        notifs: List[EmailPushAction],
         room_state_ids: StateMap[str],
         notif_events: Dict[str, EventBase],
         user_id: str,
@@ -685,10 +697,10 @@ class Mailer:
 
     async def _make_summary_text(
         self,
-        notifs_by_room: Dict[str, List[Dict[str, Any]]],
+        notifs_by_room: Dict[str, List[EmailPushAction]],
         room_state_ids: Dict[str, StateMap[str]],
         notif_events: Dict[str, EventBase],
-        reason: Dict[str, Any],
+        reason: EmailReason,
     ) -> str:
         """
         Make a summary text for the email when multiple rooms have notifications.
@@ -718,7 +730,7 @@ class Mailer:
     async def _make_summary_text_from_member_events(
         self,
         room_id: str,
-        notifs: List[Dict[str, Any]],
+        notifs: List[EmailPushAction],
         room_state_ids: StateMap[str],
         notif_events: Dict[str, EventBase],
     ) -> str:
@@ -805,7 +817,7 @@ class Mailer:
             base_url = "https://matrix.to/#"
         return "%s/%s" % (base_url, room_id)
 
-    def _make_notif_link(self, notif: Dict[str, str]) -> str:
+    def _make_notif_link(self, notif: EmailPushAction) -> str:
         """
         Generate a link to open an event in the web client.
 
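The `msgtype` and `format` changes above share a defensive-narrowing pattern: event content is untyped JSON, so a field is reduced to `None` unless it really is a string before it is stored in a typed `MessageVars`. A standalone sketch of the pattern (the content dict here is made up):

    from typing import Optional

    def get_str_or_none(content: dict, key: str) -> Optional[str]:
        """Return content[key] if it is a str, otherwise None."""
        value = content.get(key)
        if not isinstance(value, str):
            return None
        return value

    content = {"msgtype": "m.image", "format": 42}  # wrong type for "format"
    assert get_str_or_none(content, "msgtype") == "m.image"
    assert get_str_or_none(content, "format") is None
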
diff --git a/synapse/push/push_types.py b/synapse/push/push_types.py
new file mode 100644
index 0000000000..8d16ab62ce
--- /dev/null
+++ b/synapse/push/push_types.py
@@ -0,0 +1,136 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import List, Optional
+
+from typing_extensions import TypedDict
+
+
+class EmailReason(TypedDict, total=False):
+    """
+    Information on the event that triggered the email to be sent
+
+    room_id: the ID of the room the event was sent in
+    now: timestamp in ms when the email is being sent out
+    room_name: a human-readable name for the room the event was sent in
+    received_at: the time in milliseconds at which the event was received
+    delay_before_mail_ms: the amount of time in milliseconds Synapse always waits
+            before emailing about a notification (to give the user a chance to respond
+            via another push mechanism or to notice the window)
+    last_sent_ts: the time in milliseconds at which a notification was last sent
+            for an event in this room
+    throttle_ms: the minimum amount of time in milliseconds between two
+            notifications being sent for this room
+    """
+
+    room_id: str
+    now: int
+    room_name: Optional[str]
+    received_at: int
+    delay_before_mail_ms: int
+    last_sent_ts: int
+    throttle_ms: int
+
+
+class MessageVars(TypedDict, total=False):
+    """
+    Details about a specific message to include in a notification
+
+    event_type: the type of the event
+    is_historical: a boolean, which is `False` if the message is the one
+                that triggered the notification, `True` otherwise
+    id: the ID of the event
+    ts: the time in milliseconds at which the event was sent
+    sender_name: the display name for the event's sender
+    sender_avatar_url: the avatar URL (as a `mxc://` URL) for the event's
+                sender
+    sender_hash: a hash of the user ID of the sender
+    msgtype: the type of the message
+    body_text_html: HTML representation of the message
+    body_text_plain: plaintext representation of the message
+    image_url: `mxc://` URL of an image, when "msgtype" is "m.image"
+    """
+
+    event_type: str
+    is_historical: bool
+    id: str
+    ts: int
+    sender_name: str
+    sender_avatar_url: Optional[str]
+    sender_hash: int
+    msgtype: Optional[str]
+    body_text_html: str
+    body_text_plain: str
+    image_url: str
+
+
+class NotifVars(TypedDict):
+    """
+    Details about an event we are about to include in a notification
+
+    link: a `matrix.to` link to the event
+    ts: the time in milliseconds at which the event was received
+    messages: a list of messages containing one message before the event, the
+              message in the event, and one message after the event.
+    """
+
+    link: str
+    ts: Optional[int]
+    messages: List[MessageVars]
+
+
+class RoomVars(TypedDict):
+    """
+    Represents a room containing events to include in the email.
+
+    title: a human-readable name for the room
+    hash: a hash of the ID of the room
+    invite: a boolean, which is `True` if the room is an invite the user hasn't
+        accepted yet, `False` otherwise
+    notifs: a list of events, or an empty list if `invite` is `True`.
+    link: a `matrix.to` link to the room
+    avatar_url: URL to the room's avatar
+    """
+
+    title: Optional[str]
+    hash: int
+    invite: bool
+    notifs: List[NotifVars]
+    link: str
+    avatar_url: Optional[str]
+
+
+class TemplateVars(TypedDict, total=False):
+    """
+    Generic structure for passing to the email sender; holds all the fields used in email templates.
+
+    app_name: name of the app/service this homeserver is associated with
+    server_name: name of our own homeserver
+    link: a link to include into the email to be sent
+    user_display_name: the display name for the user receiving the notification
+    unsubscribe_link: the link users can click to unsubscribe from email notifications
+    summary_text: a summary of the notification(s). The text used can be customised
+              by configuring the various settings in the `email.subjects` section of the
+              configuration file.
+    rooms: a list of rooms containing events to include in the email
+    reason: information on the event that triggered the email to be sent
+    """
+
+    app_name: str
+    server_name: str
+    link: str
+    user_display_name: str
+    unsubscribe_link: str
+    summary_text: str
+    rooms: List[RoomVars]
+    reason: EmailReason
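Since `MessageVars`, `EmailReason` and `TemplateVars` are declared with `total=False`, they can be populated incrementally, which is how the mailer builds them up; `NotifVars` and `RoomVars` require every key. A hedged construction example with made-up identifiers and timestamps:

    from synapse.push.push_types import MessageVars, NotifVars, RoomVars, TemplateVars

    message: MessageVars = {  # total=False: unused keys may be omitted
        "event_type": "m.room.message",
        "is_historical": False,
        "id": "$event:example.org",
        "ts": 1636000000000,
        "sender_name": "Alice",
        "sender_avatar_url": None,
        "sender_hash": 1234,
        "msgtype": "m.text",
    }
    notif: NotifVars = {  # total=True: all keys required
        "link": "https://matrix.to/#/!room:example.org/$event:example.org",
        "ts": 1636000000000,
        "messages": [message],
    }
    room: RoomVars = {
        "title": "Example room",
        "hash": 42,
        "invite": False,
        "notifs": [notif],
        "link": "https://matrix.to/#/!room:example.org",
        "avatar_url": None,
    }
    template_vars: TemplateVars = {"summary_text": "1 new message", "rooms": [room]}
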
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 154e5b7028..7d26954244 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -86,7 +86,7 @@ REQUIREMENTS = [
     # We enforce that we have a `cryptography` version that bundles an `openssl`
     # with the latest security patches.
     "cryptography>=3.4.7",
-    "ijson>=3.0",
+    "ijson>=3.1",
 ]
 
 CONDITIONAL_REQUIREMENTS = {
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index 8c1bf9227a..fa132d10b4 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -14,10 +14,18 @@
 from typing import List, Optional, Tuple
 
 from synapse.storage.database import LoggingDatabaseConnection
-from synapse.storage.util.id_generators import _load_current_id
+from synapse.storage.util.id_generators import AbstractStreamIdTracker, _load_current_id
 
 
-class SlavedIdTracker:
+class SlavedIdTracker(AbstractStreamIdTracker):
+    """Tracks the "current" stream ID of a stream with a single writer.
+
+    See `AbstractStreamIdTracker` for more details.
+
+    Note that this class does not work correctly when there are multiple
+    writers.
+    """
+
     def __init__(
         self,
         db_conn: LoggingDatabaseConnection,
@@ -36,17 +44,7 @@ class SlavedIdTracker:
         self._current = (max if self.step > 0 else min)(self._current, new_id)
 
     def get_current_token(self) -> int:
-        """
-
-        Returns:
-            int
-        """
         return self._current
 
     def get_current_token_for_writer(self, instance_name: str) -> int:
-        """Returns the position of the given writer.
-
-        For streams with single writers this is equivalent to
-        `get_current_token`.
-        """
         return self.get_current_token()
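The dropped docstrings are now inherited from `AbstractStreamIdTracker`. The contract it captures is small; here is a toy, in-memory illustration (not Synapse's real classes) of how a single-writer tracker answers per-writer queries with the one shared position:

    class ToyStreamIdTracker:
        def __init__(self, current: int, step: int = 1) -> None:
            self._current = current
            self.step = step

        def advance(self, instance_name: str, new_id: int) -> None:
            # Forward streams (step > 0) keep the max ID seen; backward
            # streams keep the min, mirroring SlavedIdTracker.advance above.
            self._current = (max if self.step > 0 else min)(self._current, new_id)

        def get_current_token(self) -> int:
            return self._current

        def get_current_token_for_writer(self, instance_name: str) -> int:
            # With a single writer, every writer sees the same position.
            return self.get_current_token()

    tracker = ToyStreamIdTracker(current=10)
    tracker.advance("worker1", 15)
    assert tracker.get_current_token() == 15
    assert tracker.get_current_token_for_writer("anything") == 15
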
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 4d5f862862..7541e21de9 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import PushRulesStream
 from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
 
@@ -25,9 +24,6 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
         return self._push_rules_stream_id_gen.get_current_token()
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        # We assert this for the benefit of mypy
-        assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker)
-
         if stream_name == PushRulesStream.NAME:
             self._push_rules_stream_id_gen.advance(instance_name, token)
             for row in rows:
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index a030e9299e..a390cfcb74 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import heapq
 from collections.abc import Iterable
-from typing import TYPE_CHECKING, List, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Optional, Tuple, Type
 
 import attr
 
@@ -157,7 +157,7 @@ class EventsStream(Stream):
 
         # now we fetch up to that many rows from the events table
 
-        event_rows: List[Tuple] = await self._store.get_all_new_forward_event_rows(
+        event_rows = await self._store.get_all_new_forward_event_rows(
             instance_name, from_token, current_token, target_row_count
         )
 
@@ -191,7 +191,7 @@ class EventsStream(Stream):
         # finally, fetch the ex-outliers rows. We assume there are few enough of these
         # not to bother with the limit.
 
-        ex_outliers_rows: List[Tuple] = await self._store.get_ex_outlier_stream_rows(
+        ex_outliers_rows = await self._store.get_ex_outlier_stream_rows(
             instance_name, from_token, upper_limit
         )
 
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index ee4a5e481b..c51a029bf3 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -17,6 +17,7 @@
 
 import logging
 import platform
+from http import HTTPStatus
 from typing import TYPE_CHECKING, Optional, Tuple
 
 import synapse
@@ -98,7 +99,7 @@ class VersionServlet(RestServlet):
         }
 
     def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        return 200, self.res
+        return HTTPStatus.OK, self.res
 
 
 class PurgeHistoryRestServlet(RestServlet):
@@ -130,7 +131,7 @@ class PurgeHistoryRestServlet(RestServlet):
             event = await self.store.get_event(event_id)
 
             if event.room_id != room_id:
-                raise SynapseError(400, "Event is for wrong room.")
+                raise SynapseError(HTTPStatus.BAD_REQUEST, "Event is for wrong room.")
 
             # RoomStreamToken expects [int] not Optional[int]
             assert event.internal_metadata.stream_ordering is not None
@@ -144,7 +145,9 @@ class PurgeHistoryRestServlet(RestServlet):
             ts = body["purge_up_to_ts"]
             if not isinstance(ts, int):
                 raise SynapseError(
-                    400, "purge_up_to_ts must be an int", errcode=Codes.BAD_JSON
+                    HTTPStatus.BAD_REQUEST,
+                    "purge_up_to_ts must be an int",
+                    errcode=Codes.BAD_JSON,
                 )
 
             stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts)
@@ -160,7 +163,9 @@ class PurgeHistoryRestServlet(RestServlet):
                     stream_ordering,
                 )
                 raise SynapseError(
-                    404, "there is no event to be purged", errcode=Codes.NOT_FOUND
+                    HTTPStatus.NOT_FOUND,
+                    "there is no event to be purged",
+                    errcode=Codes.NOT_FOUND,
                 )
             (stream, topo, _event_id) = r
             token = "t%d-%d" % (topo, stream)
@@ -173,7 +178,7 @@ class PurgeHistoryRestServlet(RestServlet):
             )
         else:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "must specify purge_up_to_event_id or purge_up_to_ts",
                 errcode=Codes.BAD_JSON,
             )
@@ -182,7 +187,7 @@ class PurgeHistoryRestServlet(RestServlet):
             room_id, token, delete_local_events=delete_local_events
         )
 
-        return 200, {"purge_id": purge_id}
+        return HTTPStatus.OK, {"purge_id": purge_id}
 
 
 class PurgeHistoryStatusRestServlet(RestServlet):
@@ -201,7 +206,7 @@ class PurgeHistoryStatusRestServlet(RestServlet):
         if purge_status is None:
             raise NotFoundError("purge id '%s' not found" % purge_id)
 
-        return 200, purge_status.asdict()
+        return HTTPStatus.OK, purge_status.asdict()
 
 
 ########################################################################################
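The blanket 200 → HTTPStatus.OK substitution in these admin servlets (and in the files below) is behaviour-preserving: `http.HTTPStatus` is an `enum.IntEnum`, so its members compare and format as plain integers. A quick sanity check:

    from http import HTTPStatus

    assert HTTPStatus.OK == 200
    assert int(HTTPStatus.BAD_REQUEST) == 400
    assert "%d" % HTTPStatus.NOT_FOUND == "404"
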
diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py
index d9a2f6ca15..399b205aaf 100644
--- a/synapse/rest/admin/_base.py
+++ b/synapse/rest/admin/_base.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import re
+from http import HTTPStatus
 from typing import Iterable, Pattern
 
 from synapse.api.auth import Auth
@@ -62,4 +63,4 @@ async def assert_user_is_admin(auth: Auth, user_id: UserID) -> None:
     """
     is_admin = await auth.is_server_admin(user_id)
     if not is_admin:
-        raise AuthError(403, "You are not a server admin")
+        raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")
diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py
index 80fbf32f17..2e5a6600d3 100644
--- a/synapse/rest/admin/devices.py
+++ b/synapse/rest/admin/devices.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from http import HTTPStatus
 from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.errors import NotFoundError, SynapseError
@@ -53,7 +54,7 @@ class DeviceRestServlet(RestServlet):
 
         target_user = UserID.from_string(user_id)
         if not self.hs.is_mine(target_user):
-            raise SynapseError(400, "Can only lookup local users")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
 
         u = await self.store.get_user_by_id(target_user.to_string())
         if u is None:
@@ -62,7 +63,7 @@ class DeviceRestServlet(RestServlet):
         device = await self.device_handler.get_device(
             target_user.to_string(), device_id
         )
-        return 200, device
+        return HTTPStatus.OK, device
 
     async def on_DELETE(
         self, request: SynapseRequest, user_id: str, device_id: str
@@ -71,14 +72,14 @@ class DeviceRestServlet(RestServlet):
 
         target_user = UserID.from_string(user_id)
         if not self.hs.is_mine(target_user):
-            raise SynapseError(400, "Can only lookup local users")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
 
         u = await self.store.get_user_by_id(target_user.to_string())
         if u is None:
             raise NotFoundError("Unknown user")
 
         await self.device_handler.delete_device(target_user.to_string(), device_id)
-        return 200, {}
+        return HTTPStatus.OK, {}
 
     async def on_PUT(
         self, request: SynapseRequest, user_id: str, device_id: str
@@ -87,7 +88,7 @@ class DeviceRestServlet(RestServlet):
 
         target_user = UserID.from_string(user_id)
         if not self.hs.is_mine(target_user):
-            raise SynapseError(400, "Can only lookup local users")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
 
         u = await self.store.get_user_by_id(target_user.to_string())
         if u is None:
@@ -97,7 +98,7 @@ class DeviceRestServlet(RestServlet):
         await self.device_handler.update_device(
             target_user.to_string(), device_id, body
         )
-        return 200, {}
+        return HTTPStatus.OK, {}
 
 
 class DevicesRestServlet(RestServlet):
@@ -124,14 +125,14 @@ class DevicesRestServlet(RestServlet):
 
         target_user = UserID.from_string(user_id)
         if not self.hs.is_mine(target_user):
-            raise SynapseError(400, "Can only lookup local users")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
 
         u = await self.store.get_user_by_id(target_user.to_string())
         if u is None:
             raise NotFoundError("Unknown user")
 
         devices = await self.device_handler.get_devices_by_user(target_user.to_string())
-        return 200, {"devices": devices, "total": len(devices)}
+        return HTTPStatus.OK, {"devices": devices, "total": len(devices)}
 
 
 class DeleteDevicesRestServlet(RestServlet):
@@ -155,7 +156,7 @@ class DeleteDevicesRestServlet(RestServlet):
 
         target_user = UserID.from_string(user_id)
         if not self.hs.is_mine(target_user):
-            raise SynapseError(400, "Can only lookup local users")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
 
         u = await self.store.get_user_by_id(target_user.to_string())
         if u is None:
@@ -167,4 +168,4 @@ class DeleteDevicesRestServlet(RestServlet):
         await self.device_handler.delete_devices(
             target_user.to_string(), body["devices"]
         )
-        return 200, {}
+        return HTTPStatus.OK, {}
diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py
index bbfcaf723b..5ee8b11110 100644
--- a/synapse/rest/admin/event_reports.py
+++ b/synapse/rest/admin/event_reports.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+from http import HTTPStatus
 from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.errors import Codes, NotFoundError, SynapseError
@@ -66,21 +67,23 @@ class EventReportsRestServlet(RestServlet):
 
         if start < 0:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "The start parameter must be a positive integer.",
                 errcode=Codes.INVALID_PARAM,
             )
 
         if limit < 0:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "The limit parameter must be a positive integer.",
                 errcode=Codes.INVALID_PARAM,
             )
 
         if direction not in ("f", "b"):
             raise SynapseError(
-                400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM
+                HTTPStatus.BAD_REQUEST,
+                "Unknown direction: %s" % (direction,),
+                errcode=Codes.INVALID_PARAM,
             )
 
         event_reports, total = await self.store.get_event_reports_paginate(
@@ -90,7 +93,7 @@ class EventReportsRestServlet(RestServlet):
         if (start + limit) < total:
             ret["next_token"] = start + len(event_reports)
 
-        return 200, ret
+        return HTTPStatus.OK, ret
 
 
 class EventReportDetailRestServlet(RestServlet):
@@ -127,13 +130,17 @@ class EventReportDetailRestServlet(RestServlet):
         try:
             resolved_report_id = int(report_id)
         except ValueError:
-            raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM
+            )
 
         if resolved_report_id < 0:
-            raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM
+            )
 
         ret = await self.store.get_event_report(resolved_report_id)
         if not ret:
             raise NotFoundError("Event report not found")
 
-        return 200, ret
+        return HTTPStatus.OK, ret
diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py
index 68a3ba3cb7..a27110388f 100644
--- a/synapse/rest/admin/groups.py
+++ b/synapse/rest/admin/groups.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from http import HTTPStatus
 from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.errors import SynapseError
@@ -43,7 +44,7 @@ class DeleteGroupAdminRestServlet(RestServlet):
         await assert_user_is_admin(self.auth, requester.user)
 
         if not self.is_mine_id(group_id):
-            raise SynapseError(400, "Can only delete local groups")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local groups")
 
         await self.group_server.delete_group(group_id, requester.user.to_string())
-        return 200, {}
+        return HTTPStatus.OK, {}
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 30a687d234..9e23e2d8fc 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from http import HTTPStatus
 from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
@@ -62,7 +63,7 @@ class QuarantineMediaInRoom(RestServlet):
             room_id, requester.user.to_string()
         )
 
-        return 200, {"num_quarantined": num_quarantined}
+        return HTTPStatus.OK, {"num_quarantined": num_quarantined}
 
 
 class QuarantineMediaByUser(RestServlet):
@@ -89,7 +90,7 @@ class QuarantineMediaByUser(RestServlet):
             user_id, requester.user.to_string()
         )
 
-        return 200, {"num_quarantined": num_quarantined}
+        return HTTPStatus.OK, {"num_quarantined": num_quarantined}
 
 
 class QuarantineMediaByID(RestServlet):
@@ -118,7 +119,7 @@ class QuarantineMediaByID(RestServlet):
             server_name, media_id, requester.user.to_string()
         )
 
-        return 200, {}
+        return HTTPStatus.OK, {}
 
 
 class UnquarantineMediaByID(RestServlet):
@@ -147,7 +148,7 @@ class UnquarantineMediaByID(RestServlet):
         # Remove from quarantine this media id
         await self.store.quarantine_media_by_id(server_name, media_id, None)
 
-        return 200, {}
+        return HTTPStatus.OK, {}
 
 
 class ProtectMediaByID(RestServlet):
@@ -170,7 +171,7 @@ class ProtectMediaByID(RestServlet):
         # Protect this media id
         await self.store.mark_local_media_as_safe(media_id, safe=True)
 
-        return 200, {}
+        return HTTPStatus.OK, {}
 
 
 class UnprotectMediaByID(RestServlet):
@@ -193,7 +194,7 @@ class UnprotectMediaByID(RestServlet):
         # Unprotect this media id
         await self.store.mark_local_media_as_safe(media_id, safe=False)
 
-        return 200, {}
+        return HTTPStatus.OK, {}
 
 
 class ListMediaInRoom(RestServlet):
@@ -211,11 +212,11 @@ class ListMediaInRoom(RestServlet):
         requester = await self.auth.get_user_by_req(request)
         is_admin = await self.auth.is_server_admin(requester.user)
         if not is_admin:
-            raise AuthError(403, "You are not a server admin")
+            raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")
 
         local_mxcs, remote_mxcs = await self.store.get_media_mxcs_in_room(room_id)
 
-        return 200, {"local": local_mxcs, "remote": remote_mxcs}
+        return HTTPStatus.OK, {"local": local_mxcs, "remote": remote_mxcs}
 
 
 class PurgeMediaCacheRestServlet(RestServlet):
@@ -233,13 +234,13 @@ class PurgeMediaCacheRestServlet(RestServlet):
 
         if before_ts < 0:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Query parameter before_ts must be a positive integer.",
                 errcode=Codes.INVALID_PARAM,
             )
         elif before_ts < 30000000000:  # Dec 1970 in milliseconds, Aug 2920 in seconds
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Query parameter before_ts you provided is from the year 1970. "
                 + "Double check that you are providing a timestamp in milliseconds.",
                 errcode=Codes.INVALID_PARAM,
@@ -247,7 +248,7 @@ class PurgeMediaCacheRestServlet(RestServlet):
 
         ret = await self.media_repository.delete_old_remote_media(before_ts)
 
-        return 200, ret
+        return HTTPStatus.OK, ret
 
 
 class DeleteMediaByID(RestServlet):
@@ -267,7 +268,7 @@ class DeleteMediaByID(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         if self.server_name != server_name:
-            raise SynapseError(400, "Can only delete local media")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media")
 
         if await self.store.get_local_media(media_id) is None:
             raise NotFoundError("Unknown media")
@@ -277,7 +278,7 @@ class DeleteMediaByID(RestServlet):
         deleted_media, total = await self.media_repository.delete_local_media_ids(
             [media_id]
         )
-        return 200, {"deleted_media": deleted_media, "total": total}
+        return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total}
 
 
 class DeleteMediaByDateSize(RestServlet):
@@ -304,26 +305,26 @@ class DeleteMediaByDateSize(RestServlet):
 
         if before_ts < 0:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Query parameter before_ts must be a positive integer.",
                 errcode=Codes.INVALID_PARAM,
             )
         elif before_ts < 30000000000:  # Dec 1970 in milliseconds, Aug 2920 in seconds
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Query parameter before_ts you provided is from the year 1970. "
                 + "Double check that you are providing a timestamp in milliseconds.",
                 errcode=Codes.INVALID_PARAM,
             )
         if size_gt < 0:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Query parameter size_gt must be a string representing a positive integer.",
                 errcode=Codes.INVALID_PARAM,
             )
 
         if self.server_name != server_name:
-            raise SynapseError(400, "Can only delete local media")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media")
 
         logging.info(
             "Deleting local media by timestamp: %s, size larger than: %s, keep profile media: %s"
@@ -333,7 +334,7 @@ class DeleteMediaByDateSize(RestServlet):
         deleted_media, total = await self.media_repository.delete_old_local_media(
             before_ts, size_gt, keep_profiles
         )
-        return 200, {"deleted_media": deleted_media, "total": total}
+        return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total}
 
 
 class UserMediaRestServlet(RestServlet):
@@ -369,7 +370,7 @@ class UserMediaRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         if not self.is_mine(UserID.from_string(user_id)):
-            raise SynapseError(400, "Can only look up local users")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
 
         user = await self.store.get_user_by_id(user_id)
         if user is None:
@@ -380,14 +381,14 @@ class UserMediaRestServlet(RestServlet):
 
         if start < 0:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Query parameter from must be a string representing a positive integer.",
                 errcode=Codes.INVALID_PARAM,
             )
 
         if limit < 0:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Query parameter limit must be a string representing a positive integer.",
                 errcode=Codes.INVALID_PARAM,
             )
@@ -425,7 +426,7 @@ class UserMediaRestServlet(RestServlet):
         if (start + limit) < total:
             ret["next_token"] = start + len(media)
 
-        return 200, ret
+        return HTTPStatus.OK, ret
 
     async def on_DELETE(
         self, request: SynapseRequest, user_id: str
@@ -436,7 +437,7 @@ class UserMediaRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         if not self.is_mine(UserID.from_string(user_id)):
-            raise SynapseError(400, "Can only look up local users")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
 
         user = await self.store.get_user_by_id(user_id)
         if user is None:
@@ -447,14 +448,14 @@ class UserMediaRestServlet(RestServlet):
 
         if start < 0:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Query parameter from must be a string representing a positive integer.",
                 errcode=Codes.INVALID_PARAM,
             )
 
         if limit < 0:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Query parameter limit must be a string representing a positive integer.",
                 errcode=Codes.INVALID_PARAM,
             )
@@ -492,7 +493,7 @@ class UserMediaRestServlet(RestServlet):
             ([row["media_id"] for row in media])
         )
 
-        return 200, {"deleted_media": deleted_media, "total": total}
+        return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total}
 
 
 def register_servlets_for_media_repo(hs: "HomeServer", http_server: HttpServer) -> None:
diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py
index aba48f6e7b..891b98c088 100644
--- a/synapse/rest/admin/registration_tokens.py
+++ b/synapse/rest/admin/registration_tokens.py
@@ -14,6 +14,7 @@
 
 import logging
 import string
+from http import HTTPStatus
 from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.errors import Codes, NotFoundError, SynapseError
@@ -77,7 +78,7 @@ class ListRegistrationTokensRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
         valid = parse_boolean(request, "valid")
         token_list = await self.store.get_registration_tokens(valid)
-        return 200, {"registration_tokens": token_list}
+        return HTTPStatus.OK, {"registration_tokens": token_list}
 
 
 class NewRegistrationTokenRestServlet(RestServlet):
@@ -123,16 +124,20 @@ class NewRegistrationTokenRestServlet(RestServlet):
         if "token" in body:
             token = body["token"]
             if not isinstance(token, str):
-                raise SynapseError(400, "token must be a string", Codes.INVALID_PARAM)
+                raise SynapseError(
+                    HTTPStatus.BAD_REQUEST,
+                    "token must be a string",
+                    Codes.INVALID_PARAM,
+                )
             if not (0 < len(token) <= 64):
                 raise SynapseError(
-                    400,
+                    HTTPStatus.BAD_REQUEST,
                     "token must not be empty and must not be longer than 64 characters",
                     Codes.INVALID_PARAM,
                 )
             if not set(token).issubset(self.allowed_chars_set):
                 raise SynapseError(
-                    400,
+                    HTTPStatus.BAD_REQUEST,
                     "token must consist only of characters matched by the regex [A-Za-z0-9-_]",
                     Codes.INVALID_PARAM,
                 )
@@ -142,11 +147,13 @@ class NewRegistrationTokenRestServlet(RestServlet):
             length = body.get("length", 16)
             if not isinstance(length, int):
                 raise SynapseError(
-                    400, "length must be an integer", Codes.INVALID_PARAM
+                    HTTPStatus.BAD_REQUEST,
+                    "length must be an integer",
+                    Codes.INVALID_PARAM,
                 )
             if not (0 < length <= 64):
                 raise SynapseError(
-                    400,
+                    HTTPStatus.BAD_REQUEST,
                     "length must be greater than zero and not greater than 64",
                     Codes.INVALID_PARAM,
                 )
@@ -162,7 +169,7 @@ class NewRegistrationTokenRestServlet(RestServlet):
             or (isinstance(uses_allowed, int) and uses_allowed >= 0)
         ):
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "uses_allowed must be a non-negative integer or null",
                 Codes.INVALID_PARAM,
             )
@@ -170,11 +177,15 @@ class NewRegistrationTokenRestServlet(RestServlet):
         expiry_time = body.get("expiry_time", None)
         if not isinstance(expiry_time, (int, type(None))):
             raise SynapseError(
-                400, "expiry_time must be an integer or null", Codes.INVALID_PARAM
+                HTTPStatus.BAD_REQUEST,
+                "expiry_time must be an integer or null",
+                Codes.INVALID_PARAM,
             )
         if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
             raise SynapseError(
-                400, "expiry_time must not be in the past", Codes.INVALID_PARAM
+                HTTPStatus.BAD_REQUEST,
+                "expiry_time must not be in the past",
+                Codes.INVALID_PARAM,
             )
 
         created = await self.store.create_registration_token(
@@ -182,7 +193,9 @@ class NewRegistrationTokenRestServlet(RestServlet):
         )
         if not created:
             raise SynapseError(
-                400, f"Token already exists: {token}", Codes.INVALID_PARAM
+                HTTPStatus.BAD_REQUEST,
+                f"Token already exists: {token}",
+                Codes.INVALID_PARAM,
             )
 
         resp = {
@@ -192,7 +205,7 @@ class NewRegistrationTokenRestServlet(RestServlet):
             "completed": 0,
             "expiry_time": expiry_time,
         }
-        return 200, resp
+        return HTTPStatus.OK, resp
 
 
 class RegistrationTokenRestServlet(RestServlet):
@@ -261,7 +274,7 @@ class RegistrationTokenRestServlet(RestServlet):
         if token_info is None:
             raise NotFoundError(f"No such registration token: {token}")
 
-        return 200, token_info
+        return HTTPStatus.OK, token_info
 
     async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]:
         """Update a registration token."""
@@ -277,7 +290,7 @@ class RegistrationTokenRestServlet(RestServlet):
                 or (isinstance(uses_allowed, int) and uses_allowed >= 0)
             ):
                 raise SynapseError(
-                    400,
+                    HTTPStatus.BAD_REQUEST,
                     "uses_allowed must be a non-negative integer or null",
                     Codes.INVALID_PARAM,
                 )
@@ -287,11 +300,15 @@ class RegistrationTokenRestServlet(RestServlet):
             expiry_time = body["expiry_time"]
             if not isinstance(expiry_time, (int, type(None))):
                 raise SynapseError(
-                    400, "expiry_time must be an integer or null", Codes.INVALID_PARAM
+                    HTTPStatus.BAD_REQUEST,
+                    "expiry_time must be an integer or null",
+                    Codes.INVALID_PARAM,
                 )
             if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
                 raise SynapseError(
-                    400, "expiry_time must not be in the past", Codes.INVALID_PARAM
+                    HTTPStatus.BAD_REQUEST,
+                    "expiry_time must not be in the past",
+                    Codes.INVALID_PARAM,
                 )
             new_attributes["expiry_time"] = expiry_time
 
@@ -307,7 +324,7 @@ class RegistrationTokenRestServlet(RestServlet):
         if token_info is None:
             raise NotFoundError(f"No such registration token: {token}")
 
-        return 200, token_info
+        return HTTPStatus.OK, token_info
 
     async def on_DELETE(
         self, request: SynapseRequest, token: str
@@ -316,6 +333,6 @@ class RegistrationTokenRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         if await self.store.delete_registration_token(token):
-            return 200, {}
+            return HTTPStatus.OK, {}
 
         raise NotFoundError(f"No such registration token: {token}")
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index a89dda1ba5..6bbc5510f0 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -102,7 +102,9 @@ class RoomRestV2Servlet(RestServlet):
             )
 
         if not RoomID.is_valid(room_id):
-            raise SynapseError(400, "%s is not a legal room ID" % (room_id,))
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,)
+            )
 
         if not await self._store.get_room(room_id):
             raise NotFoundError("Unknown room id %s" % (room_id,))
@@ -118,7 +120,7 @@ class RoomRestV2Servlet(RestServlet):
             force_purge=force_purge,
         )
 
-        return 200, {"delete_id": delete_id}
+        return HTTPStatus.OK, {"delete_id": delete_id}
 
 
 class DeleteRoomStatusByRoomIdRestServlet(RestServlet):
@@ -137,7 +139,9 @@ class DeleteRoomStatusByRoomIdRestServlet(RestServlet):
         await assert_requester_is_admin(self._auth, request)
 
         if not RoomID.is_valid(room_id):
-            raise SynapseError(400, "%s is not a legal room ID" % (room_id,))
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,)
+            )
 
         delete_ids = self._pagination_handler.get_delete_ids_by_room(room_id)
         if delete_ids is None:
@@ -153,7 +157,7 @@ class DeleteRoomStatusByRoomIdRestServlet(RestServlet):
                         **delete.asdict(),
                     }
                 ]
-        return 200, {"results": cast(JsonDict, response)}
+        return HTTPStatus.OK, {"results": cast(JsonDict, response)}
 
 
 class DeleteRoomStatusByDeleteIdRestServlet(RestServlet):
@@ -175,7 +179,7 @@ class DeleteRoomStatusByDeleteIdRestServlet(RestServlet):
         if delete_status is None:
             raise NotFoundError("delete id '%s' not found" % delete_id)
 
-        return 200, cast(JsonDict, delete_status.asdict())
+        return HTTPStatus.OK, cast(JsonDict, delete_status.asdict())
 
 
 class ListRoomRestServlet(RestServlet):
@@ -217,7 +221,7 @@ class ListRoomRestServlet(RestServlet):
             RoomSortOrder.STATE_EVENTS.value,
         ):
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Unknown value for order_by: %s" % (order_by,),
                 errcode=Codes.INVALID_PARAM,
             )
@@ -225,7 +229,7 @@ class ListRoomRestServlet(RestServlet):
         search_term = parse_string(request, "search_term", encoding="utf-8")
         if search_term == "":
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "search_term cannot be an empty string",
                 errcode=Codes.INVALID_PARAM,
             )
@@ -233,7 +237,9 @@ class ListRoomRestServlet(RestServlet):
         direction = parse_string(request, "dir", default="f")
         if direction not in ("f", "b"):
             raise SynapseError(
-                400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM
+                HTTPStatus.BAD_REQUEST,
+                "Unknown direction: %s" % (direction,),
+                errcode=Codes.INVALID_PARAM,
             )
 
         reverse_order = True if direction == "b" else False
@@ -265,7 +271,7 @@ class ListRoomRestServlet(RestServlet):
             else:
                 response["prev_batch"] = 0
 
-        return 200, response
+        return HTTPStatus.OK, response
 
 
 class RoomRestServlet(RestServlet):
@@ -310,7 +316,7 @@ class RoomRestServlet(RestServlet):
         members = await self.store.get_users_in_room(room_id)
         ret["joined_local_devices"] = await self.store.count_devices_by_users(members)
 
-        return 200, ret
+        return HTTPStatus.OK, ret
 
     async def on_DELETE(
         self, request: SynapseRequest, room_id: str
@@ -386,7 +392,7 @@ class RoomRestServlet(RestServlet):
         # See https://github.com/python/mypy/issues/4976#issuecomment-579883622
         # for some discussion on why this is necessary. Either way,
         # `ret` is an opaque dictionary blob as far as the rest of the app cares.
-        return 200, cast(JsonDict, ret)
+        return HTTPStatus.OK, cast(JsonDict, ret)
 
 
 class RoomMembersRestServlet(RestServlet):
@@ -413,7 +419,7 @@ class RoomMembersRestServlet(RestServlet):
         members = await self.store.get_users_in_room(room_id)
         ret = {"members": members, "total": len(members)}
 
-        return 200, ret
+        return HTTPStatus.OK, ret
 
 
 class RoomStateRestServlet(RestServlet):
@@ -452,7 +458,7 @@ class RoomStateRestServlet(RestServlet):
         )
         ret = {"state": room_state}
 
-        return 200, ret
+        return HTTPStatus.OK, ret
 
 
 class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
@@ -481,7 +487,10 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
         target_user = UserID.from_string(content["user_id"])
 
         if not self.hs.is_mine(target_user):
-            raise SynapseError(400, "This endpoint can only be used with local users")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST,
+                "This endpoint can only be used with local users",
+            )
 
         if not await self.admin_handler.get_user(target_user):
             raise NotFoundError("User not found")
@@ -527,7 +536,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
             ratelimit=False,
         )
 
-        return 200, {"room_id": room_id}
+        return HTTPStatus.OK, {"room_id": room_id}
 
 
 class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
@@ -568,7 +577,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
         # Figure out which local users currently have power in the room, if any.
         room_state = await self.state_handler.get_current_state(room_id)
         if not room_state:
-            raise SynapseError(400, "Server not in room")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Server not in room")
 
         create_event = room_state[(EventTypes.Create, "")]
         power_levels = room_state.get((EventTypes.PowerLevels, ""))
@@ -582,7 +591,9 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
             admin_users.sort(key=lambda user: user_power[user])
 
             if not admin_users:
-                raise SynapseError(400, "No local admin user in room")
+                raise SynapseError(
+                    HTTPStatus.BAD_REQUEST, "No local admin user in room"
+                )
 
             admin_user_id = None
 
@@ -599,7 +610,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
 
             if not admin_user_id:
                 raise SynapseError(
-                    400,
+                    HTTPStatus.BAD_REQUEST,
                     "No local admin user in room",
                 )
 
@@ -610,7 +621,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
             admin_user_id = create_event.sender
             if not self.is_mine_id(admin_user_id):
                 raise SynapseError(
-                    400,
+                    HTTPStatus.BAD_REQUEST,
                     "No local admin user in room",
                 )
 
@@ -639,7 +650,8 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
         except AuthError:
             # The admin user we found turned out not to have enough power.
             raise SynapseError(
-                400, "No local admin user in room with power to update power levels."
+                HTTPStatus.BAD_REQUEST,
+                "No local admin user in room with power to update power levels.",
             )
 
         # Now we check if the user we're granting admin rights to is already in
@@ -653,7 +665,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
             )
 
         if is_joined:
-            return 200, {}
+            return HTTPStatus.OK, {}
 
         join_rules = room_state.get((EventTypes.JoinRules, ""))
         is_public = False
@@ -661,7 +673,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
             is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC
 
         if is_public:
-            return 200, {}
+            return HTTPStatus.OK, {}
 
         await self.room_member_handler.update_membership(
             fake_requester,
@@ -670,7 +682,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
             action=Membership.INVITE,
         )
 
-        return 200, {}
+        return HTTPStatus.OK, {}
 
 
 class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
@@ -702,7 +714,7 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
         room_id, _ = await self.resolve_room_id(room_identifier)
 
         deleted_count = await self.store.delete_forward_extremities_for_room(room_id)
-        return 200, {"deleted": deleted_count}
+        return HTTPStatus.OK, {"deleted": deleted_count}
 
     async def on_GET(
         self, request: SynapseRequest, room_identifier: str
@@ -713,7 +725,7 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
         room_id, _ = await self.resolve_room_id(room_identifier)
 
         extremities = await self.store.get_forward_extremities_for_room(room_id)
-        return 200, {"count": len(extremities), "results": extremities}
+        return HTTPStatus.OK, {"count": len(extremities), "results": extremities}
 
 
 class RoomEventContextServlet(RestServlet):
@@ -762,7 +774,9 @@ class RoomEventContextServlet(RestServlet):
         )
 
         if not results:
-            raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
+            raise SynapseError(
+                HTTPStatus.NOT_FOUND, "Event not found.", errcode=Codes.NOT_FOUND
+            )
 
         time_now = self.clock.time_msec()
         results["events_before"] = await self._event_serializer.serialize_events(
@@ -781,7 +795,7 @@ class RoomEventContextServlet(RestServlet):
             bundle_relations=False,
         )
 
-        return 200, results
+        return HTTPStatus.OK, results
 
 
 class BlockRoomRestServlet(RestServlet):
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index 19f84f33f2..b295fb078b 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from http import HTTPStatus
 from typing import TYPE_CHECKING, Awaitable, Optional, Tuple
 
 from synapse.api.constants import EventTypes
@@ -82,11 +83,15 @@ class SendServerNoticeServlet(RestServlet):
         # but worker processes still need to initialise SendServerNoticeServlet (as it is part of the
         # admin api).
         if not self.server_notices_manager.is_enabled():
-            raise SynapseError(400, "Server notices are not enabled on this server")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "Server notices are not enabled on this server"
+            )
 
         target_user = UserID.from_string(body["user_id"])
         if not self.hs.is_mine(target_user):
-            raise SynapseError(400, "Server notices can only be sent to local users")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "Server notices can only be sent to local users"
+            )
 
         if not await self.admin_handler.get_user(target_user):
             raise NotFoundError("User not found")
@@ -99,7 +104,7 @@ class SendServerNoticeServlet(RestServlet):
             txn_id=txn_id,
         )
 
-        return 200, {"event_id": event.event_id}
+        return HTTPStatus.OK, {"event_id": event.event_id}
 
     def on_PUT(
         self, request: SynapseRequest, txn_id: str
diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py
index 948de94ccd..ca41fd45f2 100644
--- a/synapse/rest/admin/statistics.py
+++ b/synapse/rest/admin/statistics.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+from http import HTTPStatus
 from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.errors import Codes, SynapseError
@@ -53,7 +54,7 @@ class UserMediaStatisticsRestServlet(RestServlet):
             UserSortOrder.DISPLAYNAME.value,
         ):
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Unknown value for order_by: %s" % (order_by,),
                 errcode=Codes.INVALID_PARAM,
             )
@@ -61,7 +62,7 @@ class UserMediaStatisticsRestServlet(RestServlet):
         start = parse_integer(request, "from", default=0)
         if start < 0:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Query parameter from must be a string representing a positive integer.",
                 errcode=Codes.INVALID_PARAM,
             )
@@ -69,7 +70,7 @@ class UserMediaStatisticsRestServlet(RestServlet):
         limit = parse_integer(request, "limit", default=100)
         if limit < 0:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Query parameter limit must be a string representing a positive integer.",
                 errcode=Codes.INVALID_PARAM,
             )
@@ -77,7 +78,7 @@ class UserMediaStatisticsRestServlet(RestServlet):
         from_ts = parse_integer(request, "from_ts", default=0)
         if from_ts < 0:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Query parameter from_ts must be a string representing a positive integer.",
                 errcode=Codes.INVALID_PARAM,
             )
@@ -86,13 +87,13 @@ class UserMediaStatisticsRestServlet(RestServlet):
         if until_ts is not None:
             if until_ts < 0:
                 raise SynapseError(
-                    400,
+                    HTTPStatus.BAD_REQUEST,
                     "Query parameter until_ts must be a string representing a positive integer.",
                     errcode=Codes.INVALID_PARAM,
                 )
             if until_ts <= from_ts:
                 raise SynapseError(
-                    400,
+                    HTTPStatus.BAD_REQUEST,
                     "Query parameter until_ts must be greater than from_ts.",
                     errcode=Codes.INVALID_PARAM,
                 )
@@ -100,7 +101,7 @@ class UserMediaStatisticsRestServlet(RestServlet):
         search_term = parse_string(request, "search_term")
         if search_term == "":
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Query parameter search_term cannot be an empty string.",
                 errcode=Codes.INVALID_PARAM,
             )
@@ -108,7 +109,9 @@ class UserMediaStatisticsRestServlet(RestServlet):
         direction = parse_string(request, "dir", default="f")
         if direction not in ("f", "b"):
             raise SynapseError(
-                400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM
+                HTTPStatus.BAD_REQUEST,
+                "Unknown direction: %s" % (direction,),
+                errcode=Codes.INVALID_PARAM,
             )
 
         users_media, total = await self.store.get_users_media_usage_paginate(
@@ -118,4 +121,4 @@ class UserMediaStatisticsRestServlet(RestServlet):
         if (start + limit) < total:
             ret["next_token"] = start + len(users_media)
 
-        return 200, ret
+        return HTTPStatus.OK, ret
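
The hunks above (and the rest of the admin-API churn in this commit) swap bare integer status codes for the `http.HTTPStatus` enum. This is behaviour-preserving because enum members are `int` subclasses; a minimal illustration, with a made-up servlet method name:

    from http import HTTPStatus

    # HTTPStatus members compare equal to the integers they replace.
    assert HTTPStatus.OK == 200
    assert HTTPStatus.BAD_REQUEST == 400

    def on_GET_sketch():
        # Returning the enum serialises exactly like returning 200,
        # but the intent is explicit at the call site.
        return HTTPStatus.OK, {"result": "ok"}
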
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index ccd9a2a175..2a60b602b1 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -79,14 +79,14 @@ class UsersRestServletV2(RestServlet):
 
         if start < 0:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Query parameter from must be a string representing a positive integer.",
                 errcode=Codes.INVALID_PARAM,
             )
 
         if limit < 0:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Query parameter limit must be a string representing a positive integer.",
                 errcode=Codes.INVALID_PARAM,
             )
@@ -122,7 +122,7 @@ class UsersRestServletV2(RestServlet):
         if (start + limit) < total:
             ret["next_token"] = str(start + len(users))
 
-        return 200, ret
+        return HTTPStatus.OK, ret
 
 
 class UserRestServletV2(RestServlet):
@@ -172,14 +172,14 @@ class UserRestServletV2(RestServlet):
 
         target_user = UserID.from_string(user_id)
         if not self.hs.is_mine(target_user):
-            raise SynapseError(400, "Can only look up local users")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
 
         ret = await self.admin_handler.get_user(target_user)
 
         if not ret:
             raise NotFoundError("User not found")
 
-        return 200, ret
+        return HTTPStatus.OK, ret
 
     async def on_PUT(
         self, request: SynapseRequest, user_id: str
@@ -191,7 +191,10 @@ class UserRestServletV2(RestServlet):
         body = parse_json_object_from_request(request)
 
         if not self.hs.is_mine(target_user):
-            raise SynapseError(400, "This endpoint can only be used with local users")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST,
+                "This endpoint can only be used with local users",
+            )
 
         user = await self.admin_handler.get_user(target_user)
         user_id = target_user.to_string()
@@ -210,7 +213,7 @@ class UserRestServletV2(RestServlet):
 
         user_type = body.get("user_type", None)
         if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
-            raise SynapseError(400, "Invalid user type")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid user type")
 
         set_admin_to = body.get("admin", False)
         if not isinstance(set_admin_to, bool):
@@ -223,11 +226,13 @@ class UserRestServletV2(RestServlet):
         password = body.get("password", None)
         if password is not None:
             if not isinstance(password, str) or len(password) > 512:
-                raise SynapseError(400, "Invalid password")
+                raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid password")
 
         deactivate = body.get("deactivated", False)
         if not isinstance(deactivate, bool):
-            raise SynapseError(400, "'deactivated' parameter is not of type boolean")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "'deactivated' parameter is not of type boolean"
+            )
 
         # convert List[Dict[str, str]] into List[Tuple[str, str]]
         if external_ids is not None:
@@ -282,7 +287,9 @@ class UserRestServletV2(RestServlet):
                         user_id,
                     )
                 except ExternalIDReuseException:
-                    raise SynapseError(409, "External id is already in use.")
+                    raise SynapseError(
+                        HTTPStatus.CONFLICT, "External id is already in use."
+                    )
 
             if "avatar_url" in body and isinstance(body["avatar_url"], str):
                 await self.profile_handler.set_avatar_url(
@@ -293,7 +300,9 @@ class UserRestServletV2(RestServlet):
                 if set_admin_to != user["admin"]:
                     auth_user = requester.user
                     if target_user == auth_user and not set_admin_to:
-                        raise SynapseError(400, "You may not demote yourself.")
+                        raise SynapseError(
+                            HTTPStatus.BAD_REQUEST, "You may not demote yourself."
+                        )
 
                     await self.store.set_server_admin(target_user, set_admin_to)
 
@@ -319,7 +328,8 @@ class UserRestServletV2(RestServlet):
                         and self.auth_handler.can_change_password()
                     ):
                         raise SynapseError(
-                            400, "Must provide a password to re-activate an account."
+                            HTTPStatus.BAD_REQUEST,
+                            "Must provide a password to re-activate an account.",
                         )
 
                     await self.deactivate_account_handler.activate_account(
@@ -332,7 +342,7 @@ class UserRestServletV2(RestServlet):
             user = await self.admin_handler.get_user(target_user)
             assert user is not None
 
-            return 200, user
+            return HTTPStatus.OK, user
 
         else:  # create user
             displayname = body.get("displayname", None)
@@ -381,7 +391,9 @@ class UserRestServletV2(RestServlet):
                             user_id,
                         )
                 except ExternalIDReuseException:
-                    raise SynapseError(409, "External id is already in use.")
+                    raise SynapseError(
+                        HTTPStatus.CONFLICT, "External id is already in use."
+                    )
 
             if "avatar_url" in body and isinstance(body["avatar_url"], str):
                 await self.profile_handler.set_avatar_url(
@@ -429,51 +441,61 @@ class UserRegisterServlet(RestServlet):
 
         nonce = secrets.token_hex(64)
         self.nonces[nonce] = int(self.reactor.seconds())
-        return 200, {"nonce": nonce}
+        return HTTPStatus.OK, {"nonce": nonce}
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         self._clear_old_nonces()
 
         if not self.hs.config.registration.registration_shared_secret:
-            raise SynapseError(400, "Shared secret registration is not enabled")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "Shared secret registration is not enabled"
+            )
 
         body = parse_json_object_from_request(request)
 
         if "nonce" not in body:
-            raise SynapseError(400, "nonce must be specified", errcode=Codes.BAD_JSON)
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST,
+                "nonce must be specified",
+                errcode=Codes.BAD_JSON,
+            )
 
         nonce = body["nonce"]
 
         if nonce not in self.nonces:
-            raise SynapseError(400, "unrecognised nonce")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "unrecognised nonce")
 
         # Delete the nonce, so it can't be reused, even if it's invalid
         del self.nonces[nonce]
 
         if "username" not in body:
             raise SynapseError(
-                400, "username must be specified", errcode=Codes.BAD_JSON
+                HTTPStatus.BAD_REQUEST,
+                "username must be specified",
+                errcode=Codes.BAD_JSON,
             )
         else:
             if not isinstance(body["username"], str) or len(body["username"]) > 512:
-                raise SynapseError(400, "Invalid username")
+                raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid username")
 
             username = body["username"].encode("utf-8")
             if b"\x00" in username:
-                raise SynapseError(400, "Invalid username")
+                raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid username")
 
         if "password" not in body:
             raise SynapseError(
-                400, "password must be specified", errcode=Codes.BAD_JSON
+                HTTPStatus.BAD_REQUEST,
+                "password must be specified",
+                errcode=Codes.BAD_JSON,
             )
         else:
             password = body["password"]
             if not isinstance(password, str) or len(password) > 512:
-                raise SynapseError(400, "Invalid password")
+                raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid password")
 
             password_bytes = password.encode("utf-8")
             if b"\x00" in password_bytes:
-                raise SynapseError(400, "Invalid password")
+                raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid password")
 
             password_hash = await self.auth_handler.hash(password)
 
@@ -482,10 +504,12 @@ class UserRegisterServlet(RestServlet):
         displayname = body.get("displayname", None)
 
         if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
-            raise SynapseError(400, "Invalid user type")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid user type")
 
         if "mac" not in body:
-            raise SynapseError(400, "mac must be specified", errcode=Codes.BAD_JSON)
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "mac must be specified", errcode=Codes.BAD_JSON
+            )
 
         got_mac = body["mac"]
 
@@ -507,7 +531,7 @@ class UserRegisterServlet(RestServlet):
         want_mac = want_mac_builder.hexdigest()
 
         if not hmac.compare_digest(want_mac.encode("ascii"), got_mac.encode("ascii")):
-            raise SynapseError(403, "HMAC incorrect")
+            raise SynapseError(HTTPStatus.FORBIDDEN, "HMAC incorrect")
 
         # Reuse the parts of RegisterRestServlet to reduce code duplication
         from synapse.rest.client.register import RegisterRestServlet
@@ -524,7 +548,7 @@ class UserRegisterServlet(RestServlet):
         )
 
         result = await register._create_registration_details(user_id, body)
-        return 200, result
+        return HTTPStatus.OK, result
 
 
 class WhoisRestServlet(RestServlet):
@@ -552,11 +576,11 @@ class WhoisRestServlet(RestServlet):
             await assert_user_is_admin(self.auth, auth_user)
 
         if not self.hs.is_mine(target_user):
-            raise SynapseError(400, "Can only whois a local user")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only whois a local user")
 
         ret = await self.admin_handler.get_whois(target_user)
 
-        return 200, ret
+        return HTTPStatus.OK, ret
 
 
 class DeactivateAccountRestServlet(RestServlet):
@@ -575,7 +599,9 @@ class DeactivateAccountRestServlet(RestServlet):
         await assert_user_is_admin(self.auth, requester.user)
 
         if not self.is_mine(UserID.from_string(target_user_id)):
-            raise SynapseError(400, "Can only deactivate local users")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "Can only deactivate local users"
+            )
 
         if not await self.store.get_user_by_id(target_user_id):
             raise NotFoundError("User not found")
@@ -597,7 +623,7 @@ class DeactivateAccountRestServlet(RestServlet):
         else:
             id_server_unbind_result = "no-support"
 
-        return 200, {"id_server_unbind_result": id_server_unbind_result}
+        return HTTPStatus.OK, {"id_server_unbind_result": id_server_unbind_result}
 
 
 class AccountValidityRenewServlet(RestServlet):
@@ -620,7 +646,7 @@ class AccountValidityRenewServlet(RestServlet):
 
             if "user_id" not in body:
                 raise SynapseError(
-                    400,
+                    HTTPStatus.BAD_REQUEST,
                     "Missing property 'user_id' in the request body",
                 )
 
@@ -631,7 +657,7 @@ class AccountValidityRenewServlet(RestServlet):
             )
 
         res = {"expiration_ts": expiration_ts}
-        return 200, res
+        return HTTPStatus.OK, res
 
 
 class ResetPasswordRestServlet(RestServlet):
@@ -678,7 +704,7 @@ class ResetPasswordRestServlet(RestServlet):
         await self._set_password_handler.set_password(
             target_user_id, new_password_hash, logout_devices, requester
         )
-        return 200, {}
+        return HTTPStatus.OK, {}
 
 
 class SearchUsersRestServlet(RestServlet):
@@ -712,16 +738,16 @@ class SearchUsersRestServlet(RestServlet):
 
         # To allow all users to get the users list
         # if not is_admin and target_user != auth_user:
-        #     raise AuthError(403, "You are not a server admin")
+        #     raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")
 
         if not self.hs.is_mine(target_user):
-            raise SynapseError(400, "Can only users a local user")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only users a local user")
 
         term = parse_string(request, "term", required=True)
         logger.info("term: %s ", term)
 
         ret = await self.store.search_users(term)
-        return 200, ret
+        return HTTPStatus.OK, ret
 
 
 class UserAdminServlet(RestServlet):
@@ -765,11 +791,14 @@ class UserAdminServlet(RestServlet):
         target_user = UserID.from_string(user_id)
 
         if not self.hs.is_mine(target_user):
-            raise SynapseError(400, "Only local users can be admins of this homeserver")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST,
+                "Only local users can be admins of this homeserver",
+            )
 
         is_admin = await self.store.is_server_admin(target_user)
 
-        return 200, {"admin": is_admin}
+        return HTTPStatus.OK, {"admin": is_admin}
 
     async def on_PUT(
         self, request: SynapseRequest, user_id: str
@@ -785,16 +814,19 @@ class UserAdminServlet(RestServlet):
         assert_params_in_dict(body, ["admin"])
 
         if not self.hs.is_mine(target_user):
-            raise SynapseError(400, "Only local users can be admins of this homeserver")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST,
+                "Only local users can be admins of this homeserver",
+            )
 
         set_admin_to = bool(body["admin"])
 
         if target_user == auth_user and not set_admin_to:
-            raise SynapseError(400, "You may not demote yourself.")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "You may not demote yourself.")
 
         await self.store.set_server_admin(target_user, set_admin_to)
 
-        return 200, {}
+        return HTTPStatus.OK, {}
 
 
 class UserMembershipRestServlet(RestServlet):
@@ -816,7 +848,7 @@ class UserMembershipRestServlet(RestServlet):
 
         room_ids = await self.store.get_rooms_for_user(user_id)
         ret = {"joined_rooms": list(room_ids), "total": len(room_ids)}
-        return 200, ret
+        return HTTPStatus.OK, ret
 
 
 class PushersRestServlet(RestServlet):
@@ -845,7 +877,7 @@ class PushersRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         if not self.is_mine(UserID.from_string(user_id)):
-            raise SynapseError(400, "Can only look up local users")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
 
         if not await self.store.get_user_by_id(user_id):
             raise NotFoundError("User not found")
@@ -854,7 +886,10 @@ class PushersRestServlet(RestServlet):
 
         filtered_pushers = [p.as_dict() for p in pushers]
 
-        return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)}
+        return HTTPStatus.OK, {
+            "pushers": filtered_pushers,
+            "total": len(filtered_pushers),
+        }
 
 
 class UserTokenRestServlet(RestServlet):
@@ -887,16 +922,22 @@ class UserTokenRestServlet(RestServlet):
         auth_user = requester.user
 
         if not self.hs.is_mine_id(user_id):
-            raise SynapseError(400, "Only local users can be logged in as")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "Only local users can be logged in as"
+            )
 
         body = parse_json_object_from_request(request, allow_empty_body=True)
 
         valid_until_ms = body.get("valid_until_ms")
         if valid_until_ms and not isinstance(valid_until_ms, int):
-            raise SynapseError(400, "'valid_until_ms' parameter must be an int")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "'valid_until_ms' parameter must be an int"
+            )
 
         if auth_user.to_string() == user_id:
-            raise SynapseError(400, "Cannot use admin API to login as self")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "Cannot use admin API to login as self"
+            )
 
         token = await self.auth_handler.create_access_token_for_user_id(
             user_id=auth_user.to_string(),
@@ -905,7 +946,7 @@ class UserTokenRestServlet(RestServlet):
             puppets_user_id=user_id,
         )
 
-        return 200, {"access_token": token}
+        return HTTPStatus.OK, {"access_token": token}
 
 
 class ShadowBanRestServlet(RestServlet):
@@ -947,11 +988,13 @@ class ShadowBanRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         if not self.hs.is_mine_id(user_id):
-            raise SynapseError(400, "Only local users can be shadow-banned")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned"
+            )
 
         await self.store.set_shadow_banned(UserID.from_string(user_id), True)
 
-        return 200, {}
+        return HTTPStatus.OK, {}
 
     async def on_DELETE(
         self, request: SynapseRequest, user_id: str
@@ -959,11 +1002,13 @@ class ShadowBanRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         if not self.hs.is_mine_id(user_id):
-            raise SynapseError(400, "Only local users can be shadow-banned")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned"
+            )
 
         await self.store.set_shadow_banned(UserID.from_string(user_id), False)
 
-        return 200, {}
+        return HTTPStatus.OK, {}
 
 
 class RateLimitRestServlet(RestServlet):
@@ -995,7 +1040,7 @@ class RateLimitRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         if not self.hs.is_mine_id(user_id):
-            raise SynapseError(400, "Can only look up local users")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
 
         if not await self.store.get_user_by_id(user_id):
             raise NotFoundError("User not found")
@@ -1016,7 +1061,7 @@ class RateLimitRestServlet(RestServlet):
         else:
             ret = {}
 
-        return 200, ret
+        return HTTPStatus.OK, ret
 
     async def on_POST(
         self, request: SynapseRequest, user_id: str
@@ -1024,7 +1069,9 @@ class RateLimitRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         if not self.hs.is_mine_id(user_id):
-            raise SynapseError(400, "Only local users can be ratelimited")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited"
+            )
 
         if not await self.store.get_user_by_id(user_id):
             raise NotFoundError("User not found")
@@ -1036,14 +1083,14 @@ class RateLimitRestServlet(RestServlet):
 
         if not isinstance(messages_per_second, int) or messages_per_second < 0:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "%r parameter must be a positive int" % (messages_per_second,),
                 errcode=Codes.INVALID_PARAM,
             )
 
         if not isinstance(burst_count, int) or burst_count < 0:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "%r parameter must be a positive int" % (burst_count,),
                 errcode=Codes.INVALID_PARAM,
             )
@@ -1059,7 +1106,7 @@ class RateLimitRestServlet(RestServlet):
             "burst_count": ratelimit.burst_count,
         }
 
-        return 200, ret
+        return HTTPStatus.OK, ret
 
     async def on_DELETE(
         self, request: SynapseRequest, user_id: str
@@ -1067,11 +1114,13 @@ class RateLimitRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         if not self.hs.is_mine_id(user_id):
-            raise SynapseError(400, "Only local users can be ratelimited")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited"
+            )
 
         if not await self.store.get_user_by_id(user_id):
             raise NotFoundError("User not found")
 
         await self.store.delete_ratelimit_for_user(user_id)
 
-        return 200, {}
+        return HTTPStatus.OK, {}
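
For context on the `UserRegisterServlet` hunks above: shared-secret registration is a two-step flow in which the client first GETs a single-use nonce and then POSTs the user details together with an HMAC, which the server recomputes and verifies with `hmac.compare_digest`. A client-side sketch, assuming the NUL-separated field order documented for Synapse's shared-secret registration API:

    import hmac
    from hashlib import sha1

    def build_registration_mac(
        shared_secret: str,
        nonce: str,
        username: str,
        password: str,
        admin: bool = False,
    ) -> str:
        # HMAC-SHA1 keyed with the homeserver's registration_shared_secret,
        # computed over the NUL-separated request fields.
        mac = hmac.new(key=shared_secret.encode("utf8"), digestmod=sha1)
        mac.update(nonce.encode("utf8"))
        mac.update(b"\x00")
        mac.update(username.encode("utf8"))
        mac.update(b"\x00")
        mac.update(password.encode("utf8"))
        mac.update(b"\x00")
        mac.update(b"admin" if admin else b"notadmin")
        return mac.hexdigest()
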
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index 67e03dca04..09f378f919 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -14,7 +14,17 @@
 
 import logging
 import re
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
 
 from typing_extensions import TypedDict
 
@@ -28,7 +38,6 @@ from synapse.http.server import HttpServer, finish_request
 from synapse.http.servlet import (
     RestServlet,
     assert_params_in_dict,
-    parse_boolean,
     parse_bytes_from_args,
     parse_json_object_from_request,
     parse_string,
@@ -155,11 +164,14 @@ class LoginRestServlet(RestServlet):
         login_submission = parse_json_object_from_request(request)
 
         if self._msc2918_enabled:
-            # Check if this login should also issue a refresh token, as per
-            # MSC2918
-            should_issue_refresh_token = parse_boolean(
-                request, name=LoginRestServlet.REFRESH_TOKEN_PARAM, default=False
+            # Check if this login should also issue a refresh token, as per MSC2918
+            should_issue_refresh_token = login_submission.get(
+                "org.matrix.msc2918.refresh_token", False
             )
+            if not isinstance(should_issue_refresh_token, bool):
+                raise SynapseError(
+                    400, "`org.matrix.msc2918.refresh_token` should be true or false."
+                )
         else:
             should_issue_refresh_token = False
 
@@ -458,6 +470,7 @@ class RefreshTokenServlet(RestServlet):
         self.refreshable_access_token_lifetime = (
             hs.config.registration.refreshable_access_token_lifetime
         )
+        self.refresh_token_lifetime = hs.config.registration.refresh_token_lifetime
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         refresh_submission = parse_json_object_from_request(request)
@@ -467,22 +480,33 @@ class RefreshTokenServlet(RestServlet):
         if not isinstance(token, str):
             raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM)
 
-        valid_until_ms = (
-            self._clock.time_msec() + self.refreshable_access_token_lifetime
-        )
-        access_token, refresh_token = await self._auth_handler.refresh_token(
-            token, valid_until_ms
-        )
-        expires_in_ms = valid_until_ms - self._clock.time_msec()
-        return (
-            200,
-            {
-                "access_token": access_token,
-                "refresh_token": refresh_token,
-                "expires_in_ms": expires_in_ms,
-            },
+        now = self._clock.time_msec()
+        access_valid_until_ms = None
+        if self.refreshable_access_token_lifetime is not None:
+            access_valid_until_ms = now + self.refreshable_access_token_lifetime
+        refresh_valid_until_ms = None
+        if self.refresh_token_lifetime is not None:
+            refresh_valid_until_ms = now + self.refresh_token_lifetime
+
+        (
+            access_token,
+            refresh_token,
+            actual_access_token_expiry,
+        ) = await self._auth_handler.refresh_token(
+            token, access_valid_until_ms, refresh_valid_until_ms
         )
 
+        response: Dict[str, Union[str, int]] = {
+            "access_token": access_token,
+            "refresh_token": refresh_token,
+        }
+
+        # expires_in_ms is only present if the token expires
+        if actual_access_token_expiry is not None:
+            response["expires_in_ms"] = actual_access_token_expiry - now
+
+        return 200, response
+
 
 class SsoRedirectServlet(RestServlet):
     PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [
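
Two behavioural notes on the login changes above: the MSC2918 opt-in flag moves from a query parameter into the JSON body (and must now be a boolean, not a string), and `/refresh` only reports `expires_in_ms` when the new access token actually expires. A sketch of a conforming login submission; the identifier and password are placeholders:

    # Hypothetical MSC2918-style login body; the refresh-token opt-in now
    # travels in the JSON body rather than the query string.
    login_submission = {
        "type": "m.login.password",
        "identifier": {"type": "m.id.user", "user": "alice"},
        "password": "correct horse battery staple",
        "org.matrix.msc2918.refresh_token": True,  # must be a bool, else 400
    }
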
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index d2b11e39d9..11fd6cd24d 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -41,7 +41,6 @@ from synapse.http.server import HttpServer, finish_request, respond_with_html
 from synapse.http.servlet import (
     RestServlet,
     assert_params_in_dict,
-    parse_boolean,
     parse_json_object_from_request,
     parse_string,
 )
@@ -449,9 +448,13 @@ class RegisterRestServlet(RestServlet):
         if self._msc2918_enabled:
             # Check if this registration should also issue a refresh token, as
             # per MSC2918
-            should_issue_refresh_token = parse_boolean(
-                request, name="org.matrix.msc2918.refresh_token", default=False
+            should_issue_refresh_token = body.get(
+                "org.matrix.msc2918.refresh_token", False
             )
+            if not isinstance(should_issue_refresh_token, bool):
+                raise SynapseError(
+                    400, "`org.matrix.msc2918.refresh_token` should be true or false."
+                )
         else:
             should_issue_refresh_token = False
 
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 955d4e8641..73d0f7c950 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -1138,12 +1138,12 @@ class RoomSpaceSummaryRestServlet(RestServlet):
 
 
 class RoomHierarchyRestServlet(RestServlet):
-    PATTERNS = (
+    PATTERNS = [
         re.compile(
-            "^/_matrix/client/unstable/org.matrix.msc2946"
+            "^/_matrix/client/(v1|unstable/org.matrix.msc2946)"
             "/rooms/(?P<room_id>[^/]*)/hierarchy$"
         ),
-    )
+    ]
 
     def __init__(self, hs: "HomeServer"):
         super().__init__()
@@ -1168,7 +1168,7 @@ class RoomHierarchyRestServlet(RestServlet):
             )
 
         return 200, await self._room_summary_handler.get_room_hierarchy(
-            requester.user.to_string(),
+            requester,
             room_id,
             suggested_only=parse_boolean(request, "suggested_only", default=False),
             max_depth=max_depth,
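
The widened `PATTERNS` above makes the hierarchy endpoint answer on both the stable `/v1` path and the unstable MSC2946 path; a quick check of the same regex:

    import re

    HIERARCHY_PATTERN = re.compile(
        "^/_matrix/client/(v1|unstable/org.matrix.msc2946)"
        "/rooms/(?P<room_id>[^/]*)/hierarchy$"
    )

    assert HIERARCHY_PATTERN.match(
        "/_matrix/client/v1/rooms/!room:example.org/hierarchy"
    )
    assert HIERARCHY_PATTERN.match(
        "/_matrix/client/unstable/org.matrix.msc2946/rooms/!room:example.org/hierarchy"
    )
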
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 1605411b00..446204dbe5 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -764,7 +764,7 @@ class StateResolutionStore:
     store: "DataStore"
 
     def get_events(
-        self, event_ids: Iterable[str], allow_rejected: bool = False
+        self, event_ids: Collection[str], allow_rejected: bool = False
     ) -> Awaitable[Dict[str, EventBase]]:
         """Get events from the database
 
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 6edadea550..499a328201 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -17,6 +17,7 @@ import logging
 from typing import (
     Awaitable,
     Callable,
+    Collection,
     Dict,
     Iterable,
     List,
@@ -44,7 +45,7 @@ async def resolve_events_with_store(
     room_version: RoomVersion,
     state_sets: Sequence[StateMap[str]],
     event_map: Optional[Dict[str, EventBase]],
-    state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
+    state_map_factory: Callable[[Collection[str]], Awaitable[Dict[str, EventBase]]],
 ) -> StateMap[str]:
     """
     Args:
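
Narrowing `Iterable[str]` to `Collection[str]` in the two state hunks above matters because `Collection` additionally guarantees `__len__`, `__contains__`, and repeatable iteration, none of which a one-shot generator provides. A minimal sketch of the distinction:

    from typing import Collection, Iterable

    def count_and_fetch(event_ids: Collection[str]) -> int:
        # Safe: Collection supports len() and can be iterated repeatedly.
        total = len(event_ids)
        for event_id in event_ids:
            ...
        return total

    def fetch_once(event_ids: Iterable[str]) -> None:
        # An Iterable may be a generator: no len(), and a second pass
        # would silently see nothing.
        for event_id in event_ids:
            ...
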
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 0623da9aa1..3056e64ff5 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -21,7 +21,7 @@ from synapse.storage.database import LoggingTransaction  # noqa: F401
 from synapse.storage.database import make_in_list_sql_clause  # noqa: F401
 from synapse.storage.database import DatabasePool
 from synapse.storage.types import Connection
-from synapse.types import StreamToken, get_domain_from_id
+from synapse.types import get_domain_from_id
 from synapse.util import json_decoder
 
 if TYPE_CHECKING:
@@ -48,7 +48,7 @@ class SQLBaseStore(metaclass=ABCMeta):
         self,
         stream_name: str,
         instance_name: str,
-        token: StreamToken,
+        token: int,
         rows: Iterable[Any],
     ) -> None:
         pass
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index bc8364400d..d64910aded 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -12,12 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Iterable, Optional
+from typing import (
+    TYPE_CHECKING,
+    AsyncContextManager,
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    Optional,
+)
+
+import attr
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.types import Connection
 from synapse.types import JsonDict
-from synapse.util import json_encoder
+from synapse.util import Clock, json_encoder
 
 from . import engines
 
@@ -28,6 +38,45 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+ON_UPDATE_CALLBACK = Callable[[str, str, bool], AsyncContextManager[int]]
+DEFAULT_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
+MIN_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _BackgroundUpdateHandler:
+    """A handler for a given background update.
+
+    Attributes:
+        callback: The function to call to make progress on the background
+            update.
+        oneshot: Whether the update is likely to happen all in one go, ignoring
+            the supplied target duration, e.g. index creation. This is used by
+            the update controller to help correctly schedule the update.
+    """
+
+    callback: Callable[[JsonDict, int], Awaitable[int]]
+    oneshot: bool = False
+
+
+class _BackgroundUpdateContextManager:
+    BACKGROUND_UPDATE_INTERVAL_MS = 1000
+    BACKGROUND_UPDATE_DURATION_MS = 100
+
+    def __init__(self, sleep: bool, clock: Clock):
+        self._sleep = sleep
+        self._clock = clock
+
+    async def __aenter__(self) -> int:
+        if self._sleep:
+            await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000)
+
+        return self.BACKGROUND_UPDATE_DURATION_MS
+
+    async def __aexit__(self, *exc) -> None:
+        pass
+
+
 class BackgroundUpdatePerformance:
     """Tracks the how long a background update is taking to update its items"""
 
@@ -84,20 +133,22 @@ class BackgroundUpdater:
 
     MINIMUM_BACKGROUND_BATCH_SIZE = 1
     DEFAULT_BACKGROUND_BATCH_SIZE = 100
-    BACKGROUND_UPDATE_INTERVAL_MS = 1000
-    BACKGROUND_UPDATE_DURATION_MS = 100
 
     def __init__(self, hs: "HomeServer", database: "DatabasePool"):
         self._clock = hs.get_clock()
         self.db_pool = database
 
+        self._database_name = database.name()
+
         # if a background update is currently running, its name.
         self._current_background_update: Optional[str] = None
 
+        self._on_update_callback: Optional[ON_UPDATE_CALLBACK] = None
+        self._default_batch_size_callback: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None
+        self._min_batch_size_callback: Optional[MIN_BATCH_SIZE_CALLBACK] = None
+
         self._background_update_performance: Dict[str, BackgroundUpdatePerformance] = {}
-        self._background_update_handlers: Dict[
-            str, Callable[[JsonDict, int], Awaitable[int]]
-        ] = {}
+        self._background_update_handlers: Dict[str, _BackgroundUpdateHandler] = {}
         self._all_done = False
 
         # Whether we're currently running updates
@@ -107,6 +158,83 @@ class BackgroundUpdater:
         # enable/disable background updates via the admin API.
         self.enabled = True
 
+    def register_update_controller_callbacks(
+        self,
+        on_update: ON_UPDATE_CALLBACK,
+        default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
+        min_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
+    ) -> None:
+        """Register callbacks from a module for each hook."""
+        if self._on_update_callback is not None:
+            logger.warning(
+                "More than one module tried to register callbacks for controlling"
+                " background updates. Only the callbacks registered by the first module"
+                " (in order of appearance in Synapse's configuration file) that tried to"
+                " do so will be called."
+            )
+
+            return
+
+        self._on_update_callback = on_update
+
+        if default_batch_size is not None:
+            self._default_batch_size_callback = default_batch_size
+
+        if min_batch_size is not None:
+            self._min_batch_size_callback = min_batch_size
+
+    def _get_context_manager_for_update(
+        self,
+        sleep: bool,
+        update_name: str,
+        database_name: str,
+        oneshot: bool,
+    ) -> AsyncContextManager[int]:
+        """Get a context manager to run a background update with.
+
+        If a module has registered an `update_handler` callback, use the context manager
+        it returns.
+
+        Otherwise, returns a context manager that will return a default value, optionally
+        sleeping if needed.
+
+        Args:
+            sleep: Whether we can sleep between updates.
+            update_name: The name of the update.
+            database_name: The name of the database the update is being run on.
+            oneshot: Whether the update will complete all in one go, e.g. index creation.
+                In such cases the returned target duration is ignored.
+
+        Returns:
+            The target duration in milliseconds that the background update should run for.
+
+            Note: this is a *target*, and an iteration may take substantially longer or
+            shorter.
+        """
+        if self._on_update_callback is not None:
+            return self._on_update_callback(update_name, database_name, oneshot)
+
+        return _BackgroundUpdateContextManager(sleep, self._clock)
+
+    async def _default_batch_size(self, update_name: str, database_name: str) -> int:
+        """The batch size to use for the first iteration of a new background
+        update.
+        """
+        if self._default_batch_size_callback is not None:
+            return await self._default_batch_size_callback(update_name, database_name)
+
+        return self.DEFAULT_BACKGROUND_BATCH_SIZE
+
+    async def _min_batch_size(self, update_name: str, database_name: str) -> int:
+        """A lower bound on the batch size of a new background update.
+
+        Used to ensure that progress is always made. Must be greater than 0.
+        """
+        if self._min_batch_size_callback is not None:
+            return await self._min_batch_size_callback(update_name, database_name)
+
+        return self.MINIMUM_BACKGROUND_BATCH_SIZE
+
     def get_current_update(self) -> Optional[BackgroundUpdatePerformance]:
         """Returns the current background update, if any."""
 
@@ -135,13 +263,8 @@ class BackgroundUpdater:
         try:
             logger.info("Starting background schema updates")
             while self.enabled:
-                if sleep:
-                    await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
-
                 try:
-                    result = await self.do_next_background_update(
-                        self.BACKGROUND_UPDATE_DURATION_MS
-                    )
+                    result = await self.do_next_background_update(sleep)
                 except Exception:
                     logger.exception("Error doing update")
                 else:
@@ -203,13 +326,15 @@ class BackgroundUpdater:
 
         return not update_exists
 
-    async def do_next_background_update(self, desired_duration_ms: float) -> bool:
+    async def do_next_background_update(self, sleep: bool = True) -> bool:
         """Does some amount of work on the next queued background update
 
         Returns once some amount of work is done.
 
         Args:
-            desired_duration_ms: How long we want to spend updating.
+            sleep: Whether to limit how quickly we run background updates or
+                not.
+
         Returns:
             True if we have finished running all the background updates, otherwise False
         """
@@ -252,7 +377,19 @@ class BackgroundUpdater:
 
             self._current_background_update = upd["update_name"]
 
-        await self._do_background_update(desired_duration_ms)
+        # We have a background update to run, otherwise we would have returned
+        # early.
+        assert self._current_background_update is not None
+        update_info = self._background_update_handlers[self._current_background_update]
+
+        async with self._get_context_manager_for_update(
+            sleep=sleep,
+            update_name=self._current_background_update,
+            database_name=self._database_name,
+            oneshot=update_info.oneshot,
+        ) as desired_duration_ms:
+            await self._do_background_update(desired_duration_ms)
+
         return False
 
     async def _do_background_update(self, desired_duration_ms: float) -> int:
@@ -260,7 +397,7 @@ class BackgroundUpdater:
         update_name = self._current_background_update
         logger.info("Starting update batch on background update '%s'", update_name)
 
-        update_handler = self._background_update_handlers[update_name]
+        update_handler = self._background_update_handlers[update_name].callback
 
         performance = self._background_update_performance.get(update_name)
 
@@ -273,9 +410,14 @@ class BackgroundUpdater:
         if items_per_ms is not None:
             batch_size = int(desired_duration_ms * items_per_ms)
             # Clamp the batch size so that we always make progress
-            batch_size = max(batch_size, self.MINIMUM_BACKGROUND_BATCH_SIZE)
+            batch_size = max(
+                batch_size,
+                await self._min_batch_size(update_name, self._database_name),
+            )
         else:
-            batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE
+            batch_size = await self._default_batch_size(
+                update_name, self._database_name
+            )
 
         progress_json = await self.db_pool.simple_select_one_onecol(
             "background_updates",
@@ -294,6 +436,8 @@ class BackgroundUpdater:
 
         duration_ms = time_stop - time_start
 
+        performance.update(items_updated, duration_ms)
+
         logger.info(
             "Running background update %r. Processed %r items in %rms."
             " (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)",
@@ -306,8 +450,6 @@ class BackgroundUpdater:
             batch_size,
         )
 
-        performance.update(items_updated, duration_ms)
-
         return len(self._background_update_performance)
 
     def register_background_update_handler(
@@ -331,7 +473,9 @@ class BackgroundUpdater:
             update_name: The name of the update that this code handles.
             update_handler: The function that does the update.
         """
-        self._background_update_handlers[update_name] = update_handler
+        self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
+            update_handler
+        )
 
     def register_noop_background_update(self, update_name: str) -> None:
         """Register a noop handler for a background update.
@@ -453,7 +597,9 @@ class BackgroundUpdater:
             await self._end_background_update(update_name)
             return 1
 
-        self.register_background_update_handler(update_name, updater)
+        self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
+            updater, oneshot=True
+        )
 
     async def _end_background_update(self, update_name: str) -> None:
         """Removes a completed background update task from the queue.
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index d957e770dc..3efdd0c920 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -16,6 +16,7 @@ import logging
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import attr
+from typing_extensions import TypedDict
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json
@@ -37,6 +38,20 @@ DEFAULT_HIGHLIGHT_ACTION = [
 ]
 
 
+class BasePushAction(TypedDict):
+    event_id: str
+    actions: List[Union[dict, str]]
+
+
+class HttpPushAction(BasePushAction):
+    room_id: str
+    stream_ordering: int
+
+
+class EmailPushAction(HttpPushAction):
+    received_ts: Optional[int]
+
+
 def _serialize_action(actions, is_highlight):
     """Custom serializer for actions. This allows us to "compress" common actions.
 
@@ -221,7 +236,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         min_stream_ordering: int,
         max_stream_ordering: int,
         limit: int = 20,
-    ) -> List[dict]:
+    ) -> List[HttpPushAction]:
         """Get a list of the most recent unread push actions for a given user,
         within the given stream ordering range. Called by the httppusher.
 
@@ -326,7 +341,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         min_stream_ordering: int,
         max_stream_ordering: int,
         limit: int = 20,
-    ) -> List[dict]:
+    ) -> List[EmailPushAction]:
         """Get a list of the most recent unread push actions for a given user,
         within the given stream ordering range. Called by the emailpusher
 
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 06832221ad..4171b904eb 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 import itertools
 import logging
-from collections import OrderedDict, namedtuple
+from collections import OrderedDict
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -41,9 +41,10 @@ from synapse.events.snapshot import EventContext  # noqa: F401
 from synapse.logging.utils import log_function
 from synapse.storage._base import db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.databases.main.events_worker import EventCacheEntry
 from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.storage.util.id_generators import AbstractStreamIdGenerator
 from synapse.storage.util.sequence import SequenceGenerator
 from synapse.types import StateMap, get_domain_from_id
 from synapse.util import json_encoder
@@ -64,9 +65,6 @@ event_counter = Counter(
 )
 
 
-_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
-
-
 @attr.s(slots=True)
 class DeltaState:
     """Deltas to use to update the `current_state_events` table.
@@ -108,23 +106,30 @@ class PersistEventsStore:
         self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
         self.is_mine_id = hs.is_mine_id
 
-        # Ideally we'd move these ID gens here, unfortunately some other ID
-        # generators are chained off them so doing so is a bit of a PITA.
-        self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen
-        self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen
-
         # This should only exist on instances that are configured to write
         assert (
             hs.get_instance_name() in hs.config.worker.writers.events
         ), "Can only instantiate EventsStore on master"
 
+        # Since we have been configured to write, we ought to have id generators,
+        # rather than id trackers.
+        assert isinstance(self.store._backfill_id_gen, AbstractStreamIdGenerator)
+        assert isinstance(self.store._stream_id_gen, AbstractStreamIdGenerator)
+
+        # Ideally we'd move these ID gens here, unfortunately some other ID
+        # generators are chained off them so doing so is a bit of a PITA.
+        self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
+        self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
+
     async def _persist_events_and_state_updates(
         self,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
+        *,
         current_state_for_room: Dict[str, StateMap[str]],
         state_delta_for_room: Dict[str, DeltaState],
         new_forward_extremeties: Dict[str, List[str]],
-        backfilled: bool = False,
+        use_negative_stream_ordering: bool = False,
+        inhibit_local_membership_updates: bool = False,
     ) -> None:
         """Persist a set of events alongside updates to the current state and
         forward extremities tables.
@@ -137,7 +142,14 @@ class PersistEventsStore:
                 room state
             new_forward_extremities: Map from room_id to list of event IDs
                 that are the new forward extremities of the room.
-            backfilled
+            use_negative_stream_ordering: Whether to start stream_ordering on
+                the negative side and decrement. This should be set to True
+                for backfilled events because backfilled events get a negative
+                stream ordering so they don't come down incremental `/sync`.
+            inhibit_local_membership_updates: Stop the local_current_membership
+                from being updated by these events. This should be set to True
+                for backfilled events because backfilled events in the past do
+                not affect the current local state.
 
         Returns:
             Resolves when the events have been persisted
@@ -159,7 +171,7 @@ class PersistEventsStore:
         #
         # Note: Multiple instances of this function cannot be in flight at
         # the same time for the same room.
-        if backfilled:
+        if use_negative_stream_ordering:
             stream_ordering_manager = self._backfill_id_gen.get_next_mult(
                 len(events_and_contexts)
             )
@@ -176,13 +188,13 @@ class PersistEventsStore:
                 "persist_events",
                 self._persist_events_txn,
                 events_and_contexts=events_and_contexts,
-                backfilled=backfilled,
+                inhibit_local_membership_updates=inhibit_local_membership_updates,
                 state_delta_for_room=state_delta_for_room,
                 new_forward_extremeties=new_forward_extremeties,
             )
             persist_event_counter.inc(len(events_and_contexts))
 
-            if not backfilled:
+            if not use_negative_stream_ordering:
                 # backfilled events have negative stream orderings, so we don't
                 # want to set the event_persisted_position to that.
                 synapse.metrics.event_persisted_position.set(
@@ -316,8 +328,9 @@ class PersistEventsStore:
     def _persist_events_txn(
         self,
         txn: LoggingTransaction,
+        *,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
-        backfilled: bool,
+        inhibit_local_membership_updates: bool = False,
         state_delta_for_room: Optional[Dict[str, DeltaState]] = None,
         new_forward_extremeties: Optional[Dict[str, List[str]]] = None,
     ):
@@ -330,7 +343,10 @@ class PersistEventsStore:
         Args:
             txn
             events_and_contexts: events to persist
-            backfilled: True if the events were backfilled
+            inhibit_local_membership_updates: Stop the local_current_membership
+                from being updated by these events. This should be set to True
+                for backfilled events because backfilled events in the past do
+                not affect the current local state.
             delete_existing True to purge existing table rows for the events
                 from the database. This is useful when retrying due to
                 IntegrityError.
@@ -363,9 +379,7 @@ class PersistEventsStore:
             events_and_contexts
         )
 
-        self._update_room_depths_txn(
-            txn, events_and_contexts=events_and_contexts, backfilled=backfilled
-        )
+        self._update_room_depths_txn(txn, events_and_contexts=events_and_contexts)
 
         # _update_outliers_txn filters out any events which have already been
         # persisted, and returns the filtered list.
@@ -398,7 +412,7 @@ class PersistEventsStore:
             txn,
             events_and_contexts=events_and_contexts,
             all_events_and_contexts=all_events_and_contexts,
-            backfilled=backfilled,
+            inhibit_local_membership_updates=inhibit_local_membership_updates,
         )
 
         # We call this last as it assumes we've inserted the events into
@@ -1200,7 +1214,6 @@ class PersistEventsStore:
         self,
         txn,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
-        backfilled: bool,
     ):
         """Update min_depth for each room
 
@@ -1208,13 +1221,18 @@ class PersistEventsStore:
             txn (twisted.enterprise.adbapi.Connection): db connection
             events_and_contexts (list[(EventBase, EventContext)]): events
                 we are persisting
-            backfilled (bool): True if the events were backfilled
         """
         depth_updates: Dict[str, int] = {}
         for event, context in events_and_contexts:
             # Remove any existing cache entries for the event_ids
             txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
-            if not backfilled:
+            # Then update the `stream_ordering` position to mark the latest
+            # event as the front of the room. This should not be done for
+            # backfilled events because backfilled events have negative
+            # stream_ordering and happened in the past so we know that we don't
+            # need to update the stream_ordering tip/front for the room.
+            assert event.internal_metadata.stream_ordering is not None
+            if event.internal_metadata.stream_ordering >= 0:
                 txn.call_after(
                     self.store._events_stream_cache.entity_has_changed,
                     event.room_id,
@@ -1427,7 +1445,12 @@ class PersistEventsStore:
         return [ec for ec in events_and_contexts if ec[0] not in to_remove]
 
     def _update_metadata_tables_txn(
-        self, txn, events_and_contexts, all_events_and_contexts, backfilled
+        self,
+        txn,
+        *,
+        events_and_contexts,
+        all_events_and_contexts,
+        inhibit_local_membership_updates: bool = False,
     ):
         """Update all the miscellaneous tables for new events
 
@@ -1439,7 +1462,10 @@ class PersistEventsStore:
                 events that we were going to persist. This includes events
                 we've already persisted, etc, that wouldn't appear in
                 events_and_context.
-            backfilled (bool): True if the events were backfilled
+            inhibit_local_membership_updates: Stop the local_current_membership
+                from being updated by these events. This should be set to True
+                for backfilled events because backfilled events in the past do
+                not affect the current local state.
         """
 
         # Insert all the push actions into the event_push_actions table.
@@ -1513,7 +1539,7 @@ class PersistEventsStore:
                 for event, _ in events_and_contexts
                 if event.type == EventTypes.Member
             ],
-            backfilled=backfilled,
+            inhibit_local_membership_updates=inhibit_local_membership_updates,
         )
 
         # Insert event_reference_hashes table.
@@ -1553,11 +1579,13 @@ class PersistEventsStore:
         for row in rows:
             event = ev_map[row["event_id"]]
             if not row["rejects"] and not row["redacts"]:
-                to_prefill.append(_EventCacheEntry(event=event, redacted_event=None))
+                to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
 
         def prefill():
             for cache_entry in to_prefill:
-                self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry)
+                self.store._get_event_cache.set(
+                    (cache_entry.event.event_id,), cache_entry
+                )
 
         txn.call_after(prefill)
 
@@ -1638,8 +1666,19 @@ class PersistEventsStore:
             txn, table="event_reference_hashes", values=vals
         )
 
-    def _store_room_members_txn(self, txn, events, backfilled):
-        """Store a room member in the database."""
+    def _store_room_members_txn(
+        self, txn, events, *, inhibit_local_membership_updates: bool = False
+    ):
+        """
+        Store a room member in the database.
+        Args:
+            txn: The transaction to use.
+            events: List of events to store.
+            inhibit_local_membership_updates: Stop the local_current_membership
+                from being updated by these events. This should be set to True
+                for backfilled events because backfilled events in the past do
+                not affect the current local state.
+        """
 
         def non_null_str_or_none(val: Any) -> Optional[str]:
             return val if isinstance(val, str) and "\u0000" not in val else None
@@ -1682,7 +1721,7 @@ class PersistEventsStore:
             # band membership", like a remote invite or a rejection of a remote invite.
             if (
                 self.is_mine_id(event.state_key)
-                and not backfilled
+                and not inhibit_local_membership_updates
                 and event.internal_metadata.is_outlier()
                 and event.internal_metadata.is_out_of_band_membership()
             ):
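
The events.py hunks above replace the catch-all `backfilled` boolean with two keyword-only flags that name the behaviours it used to bundle. A hedged sketch of how a caller might translate the old flag; the real call site lives in persist_events.py, which is not shown here:

    # Illustrative only: both flags derive from the old `backfilled` boolean,
    # since backfilled events want negative stream orderings *and* must not
    # touch local_current_membership.
    async def persist(store, events_and_contexts, backfilled: bool) -> None:
        await store._persist_events_and_state_updates(
            events_and_contexts,
            current_state_for_room={},
            state_delta_for_room={},
            new_forward_extremeties={},
            use_negative_stream_ordering=backfilled,
            inhibit_local_membership_updates=backfilled,
        )
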
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index c6bf316d5b..4cefc0a07e 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -15,14 +15,18 @@
 import logging
 import threading
 from typing import (
+    TYPE_CHECKING,
+    Any,
     Collection,
     Container,
     Dict,
     Iterable,
     List,
+    NoReturn,
     Optional,
     Set,
     Tuple,
+    cast,
     overload,
 )
 
@@ -38,6 +42,7 @@ from synapse.api.errors import NotFoundError, SynapseError
 from synapse.api.room_versions import (
     KNOWN_ROOM_VERSIONS,
     EventFormatVersions,
+    RoomVersion,
     RoomVersions,
 )
 from synapse.events import EventBase, make_event_from_dict
@@ -56,10 +61,18 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import BackfillStream
 from synapse.replication.tcp.streams.events import EventsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import (
+    AbstractStreamIdTracker,
+    MultiWriterIdGenerator,
+    StreamIdGenerator,
+)
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import JsonDict, get_domain_from_id
 from synapse.util import unwrapFirstError
@@ -69,10 +82,13 @@ from synapse.util.caches.lrucache import LruCache
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
-# These values are used in the `enqueus_event` and `_do_fetch` methods to
+# These values are used in the `_enqueue_events` and `_fetch_loop` methods to
 # control how we batch/bulk fetch events from the database.
 # The values are plucked out of thin air to make initial sync run faster
 # on jki.re
@@ -89,7 +105,7 @@ event_fetch_ongoing_gauge = Gauge(
 
 
 @attr.s(slots=True, auto_attribs=True)
-class _EventCacheEntry:
+class EventCacheEntry:
     event: EventBase
     redacted_event: Optional[EventBase]
 
@@ -129,7 +145,7 @@ class _EventRow:
     json: str
     internal_metadata: str
     format_version: Optional[int]
-    room_version_id: Optional[int]
+    room_version_id: Optional[str]
     rejected_reason: Optional[str]
     redactions: List[str]
     outlier: bool
@@ -153,9 +169,16 @@ class EventsWorkerStore(SQLBaseStore):
     # options controlling this.
     USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
+        self._stream_id_gen: AbstractStreamIdTracker
+        self._backfill_id_gen: AbstractStreamIdTracker
         if isinstance(database.engine, PostgresEngine):
             # If we're using Postgres then we can use `MultiWriterIdGenerator`
             # regardless of whether this process writes to the streams or not.
@@ -214,7 +237,7 @@ class EventsWorkerStore(SQLBaseStore):
                 5 * 60 * 1000,
             )
 
-        self._get_event_cache = LruCache(
+        self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache(
             cache_name="*getEvent*",
             max_size=hs.config.caches.event_cache_size,
         )
@@ -223,19 +246,21 @@ class EventsWorkerStore(SQLBaseStore):
         # ID to cache entry. Note that the returned dict may not have the
         # requested event in it if the event isn't in the DB.
         self._current_event_fetches: Dict[
-            str, ObservableDeferred[Dict[str, _EventCacheEntry]]
+            str, ObservableDeferred[Dict[str, EventCacheEntry]]
         ] = {}
 
         self._event_fetch_lock = threading.Condition()
-        self._event_fetch_list = []
+        self._event_fetch_list: List[
+            Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]
+        ] = []
         self._event_fetch_ongoing = 0
         event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
 
         # We define this sequence here so that it can be referenced from both
         # the DataStore and PersistEventStore.
-        def get_chain_id_txn(txn):
+        def get_chain_id_txn(txn: Cursor) -> int:
             txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
-            return txn.fetchone()[0]
+            return cast(Tuple[int], txn.fetchone())[0]
 
         self.event_chain_id_gen = build_sequence_generator(
             db_conn,
@@ -246,7 +271,13 @@ class EventsWorkerStore(SQLBaseStore):
             id_column="chain_id",
         )
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self,
+        stream_name: str,
+        instance_name: str,
+        token: int,
+        rows: Iterable[Any],
+    ) -> None:
         if stream_name == EventsStream.NAME:
             self._stream_id_gen.advance(instance_name, token)
         elif stream_name == BackfillStream.NAME:
@@ -280,10 +311,10 @@ class EventsWorkerStore(SQLBaseStore):
         self,
         event_id: str,
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
-        get_prev_content: bool = False,
-        allow_rejected: bool = False,
-        allow_none: Literal[False] = False,
-        check_room_id: Optional[str] = None,
+        get_prev_content: bool = ...,
+        allow_rejected: bool = ...,
+        allow_none: Literal[False] = ...,
+        check_room_id: Optional[str] = ...,
     ) -> EventBase:
         ...
 
@@ -292,10 +323,10 @@ class EventsWorkerStore(SQLBaseStore):
         self,
         event_id: str,
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
-        get_prev_content: bool = False,
-        allow_rejected: bool = False,
-        allow_none: Literal[True] = False,
-        check_room_id: Optional[str] = None,
+        get_prev_content: bool = ...,
+        allow_rejected: bool = ...,
+        allow_none: Literal[True] = ...,
+        check_room_id: Optional[str] = ...,
     ) -> Optional[EventBase]:
         ...
 
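
The `...` defaults in the @overload stubs avoid restating (and possibly contradicting) the real defaults, which live only on the implementation. A self-contained illustration of the idiom; `get_event` here is invented for the example:

from typing import Literal, Optional, overload

@overload
def get_event(event_id: str, allow_none: Literal[False] = ...) -> str:
    ...

@overload
def get_event(event_id: str, allow_none: Literal[True] = ...) -> Optional[str]:
    ...

def get_event(event_id: str, allow_none: bool = False) -> Optional[str]:
    # The implementation is the only place the real default is written; the
    # overloads use `...` to say "a default exists" without repeating it.
    if event_id == "$missing":
        if allow_none:
            return None
        raise KeyError(event_id)
    return f"event {event_id}"

print(get_event("$abc"))                       # checked as str
print(get_event("$missing", allow_none=True))  # checked as Optional[str]
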
@@ -357,7 +388,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def get_events(
         self,
-        event_ids: Iterable[str],
+        event_ids: Collection[str],
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
         get_prev_content: bool = False,
         allow_rejected: bool = False,
@@ -544,7 +575,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def _get_events_from_cache_or_db(
         self, event_ids: Iterable[str], allow_rejected: bool = False
-    ) -> Dict[str, _EventCacheEntry]:
+    ) -> Dict[str, EventCacheEntry]:
         """Fetch a bunch of events from the cache or the database.
 
         If events are pulled from the database, they will be cached for future lookups.
@@ -578,7 +609,7 @@ class EventsWorkerStore(SQLBaseStore):
         # same dict into itself N times).
         already_fetching_ids: Set[str] = set()
         already_fetching_deferreds: Set[
-            ObservableDeferred[Dict[str, _EventCacheEntry]]
+            ObservableDeferred[Dict[str, EventCacheEntry]]
         ] = set()
 
         for event_id in missing_events_ids:
@@ -601,8 +632,8 @@ class EventsWorkerStore(SQLBaseStore):
             # function returning more events than requested, but that can happen
             # already due to `_get_events_from_db`).
             fetching_deferred: ObservableDeferred[
-                Dict[str, _EventCacheEntry]
-            ] = ObservableDeferred(defer.Deferred())
+                Dict[str, EventCacheEntry]
+            ] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
             for event_id in missing_events_ids:
                 self._current_event_fetches[event_id] = fetching_deferred
 
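
`consumeErrors=True` marks a failure as observed once every waiter has received it, so Twisted does not log it as unhandled. The surrounding deduplication (many callers awaiting one in-flight database fetch) is independent of Twisted; a minimal asyncio rendering, with invented names:

import asyncio
from typing import Dict

_in_flight: Dict[str, "asyncio.Future[str]"] = {}

async def fetch_event(event_id: str) -> str:
    existing = _in_flight.get(event_id)
    if existing is not None:
        # Someone is already fetching this event: await their result
        # rather than issuing a duplicate database query.
        return await existing

    fut: "asyncio.Future[str]" = asyncio.get_running_loop().create_future()
    _in_flight[event_id] = fut
    try:
        await asyncio.sleep(0.01)  # stand-in for the DB round trip
        result = f"row-for-{event_id}"
        fut.set_result(result)
        return result
    except BaseException as exc:
        fut.set_exception(exc)  # propagate the failure to all waiters
        raise
    finally:
        # Always clear the map so later callers retry after a failure.
        del _in_flight[event_id]

async def main() -> None:
    rows = await asyncio.gather(*(fetch_event("$ev1") for _ in range(3)))
    print(rows)  # three callers, one fetch

asyncio.run(main())
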
@@ -658,12 +689,12 @@ class EventsWorkerStore(SQLBaseStore):
 
         return event_entry_map
 
-    def _invalidate_get_event_cache(self, event_id):
+    def _invalidate_get_event_cache(self, event_id: str) -> None:
         self._get_event_cache.invalidate((event_id,))
 
     def _get_events_from_cache(
         self, events: Iterable[str], update_metrics: bool = True
-    ) -> Dict[str, _EventCacheEntry]:
+    ) -> Dict[str, EventCacheEntry]:
         """Fetch events from the caches.
 
         May return rejected events.
@@ -736,38 +767,123 @@ class EventsWorkerStore(SQLBaseStore):
             for e in state_to_include.values()
         ]
 
-    def _do_fetch(self, conn: Connection) -> None:
+    def _maybe_start_fetch_thread(self) -> None:
+        """Starts an event fetch thread if we are not yet at the maximum number."""
+        with self._event_fetch_lock:
+            if (
+                self._event_fetch_list
+                and self._event_fetch_ongoing < EVENT_QUEUE_THREADS
+            ):
+                self._event_fetch_ongoing += 1
+                event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+                # `_event_fetch_ongoing` is decremented in `_fetch_thread`.
+                should_start = True
+            else:
+                should_start = False
+
+        if should_start:
+            run_as_background_process("fetch_events", self._fetch_thread)
+
+    async def _fetch_thread(self) -> None:
+        """Services requests for events from `_event_fetch_list`."""
+        exc = None
+        try:
+            await self.db_pool.runWithConnection(self._fetch_loop)
+        except BaseException as e:
+            exc = e
+            raise
+        finally:
+            should_restart = False
+            event_fetches_to_fail = []
+            with self._event_fetch_lock:
+                self._event_fetch_ongoing -= 1
+                event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+
+                # There may still be work remaining in `_event_fetch_list` if we
+                # failed, or it was added in between us deciding to exit and
+                # decrementing `_event_fetch_ongoing`.
+                if self._event_fetch_list:
+                    if exc is None:
+                        # We decided to exit, but then some more work was added
+                        # before `_event_fetch_ongoing` was decremented.
+                        # If a new event fetch thread was not started, we should
+                        # restart ourselves since the remaining event fetch threads
+                        # may take a while to get around to the new work.
+                        #
+                        # Unfortunately it is not possible to tell whether a new
+                        # event fetch thread was started, so we restart
+                        # unconditionally. If we are unlucky, we will end up with
+                        # an idle fetch thread, but it will time out after
+                        # `EVENT_QUEUE_ITERATIONS * EVENT_QUEUE_TIMEOUT_S` seconds
+                        # in any case.
+                        #
+                        # Note that multiple fetch threads may run down this path at
+                        # the same time.
+                        should_restart = True
+                    elif isinstance(exc, Exception):
+                        if self._event_fetch_ongoing == 0:
+                            # We were the last remaining fetcher and failed.
+                            # Fail any outstanding fetches since no one else will
+                            # handle them.
+                            event_fetches_to_fail = self._event_fetch_list
+                            self._event_fetch_list = []
+                        else:
+                            # We weren't the last remaining fetcher, so another
+                            # fetcher will pick up the work. This will either happen
+                            # after their existing work, however long that takes,
+                            # or after at most `EVENT_QUEUE_TIMEOUT_S` seconds if
+                            # they are idle.
+                            pass
+                    else:
+                        # The exception is a `SystemExit`, `KeyboardInterrupt` or
+                        # `GeneratorExit`. Don't try to do anything clever here.
+                        pass
+
+            if should_restart:
+                # We exited cleanly but noticed more work.
+                self._maybe_start_fetch_thread()
+
+            if event_fetches_to_fail:
+                # We were the last remaining fetcher and failed.
+                # Fail any outstanding fetches since no one else will handle them.
+                assert exc is not None
+                with PreserveLoggingContext():
+                    for _, deferred in event_fetches_to_fail:
+                        deferred.errback(exc)
+
+    def _fetch_loop(self, conn: LoggingDatabaseConnection) -> None:
         """Takes a database connection and waits for requests for events from
         the _event_fetch_list queue.
         """
-        try:
-            i = 0
-            while True:
-                with self._event_fetch_lock:
-                    event_list = self._event_fetch_list
-                    self._event_fetch_list = []
-
-                    if not event_list:
-                        single_threaded = self.database_engine.single_threaded
-                        if (
-                            not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
-                            or single_threaded
-                            or i > EVENT_QUEUE_ITERATIONS
-                        ):
-                            break
-                        else:
-                            self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
-                            i += 1
-                            continue
-                    i = 0
+        i = 0
+        while True:
+            with self._event_fetch_lock:
+                event_list = self._event_fetch_list
+                self._event_fetch_list = []
+
+                if not event_list:
+                    # There are no requests waiting. If we haven't yet reached the
+                    # maximum iteration limit, wait for some more requests to turn up.
+                    # Otherwise, bail out.
+                    single_threaded = self.database_engine.single_threaded
+                    if (
+                        not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
+                        or single_threaded
+                        or i > EVENT_QUEUE_ITERATIONS
+                    ):
+                        return
+
+                    self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
+                    i += 1
+                    continue
+                i = 0
 
-                self._fetch_event_list(conn, event_list)
-        finally:
-            self._event_fetch_ongoing -= 1
-            event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+            self._fetch_event_list(conn, event_list)
 
     def _fetch_event_list(
-        self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
+        self,
+        conn: LoggingDatabaseConnection,
+        event_list: List[Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]],
     ) -> None:
         """Handle a load of requests from the _event_fetch_list queue
 
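
`_maybe_start_fetch_thread` now owns the "check the queue and bump the counter under the lock, spawn outside the lock" dance, and `_fetch_thread` re-checks the queue after decrementing so work enqueued during shutdown is not stranded. The same pattern with plain threads, reduced to a runnable sketch (the class and names are invented):

import threading
from typing import List

MAX_WORKERS = 3

class FetchPool:
    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._queue: List[str] = []
        self._ongoing = 0

    def add_work(self, item: str) -> None:
        with self._lock:
            self._queue.append(item)
        self._maybe_start_worker()

    def _maybe_start_worker(self) -> None:
        with self._lock:
            should_start = bool(self._queue) and self._ongoing < MAX_WORKERS
            if should_start:
                self._ongoing += 1  # decremented in _worker's finally
        if should_start:
            threading.Thread(target=self._worker).start()

    def _worker(self) -> None:
        try:
            while True:
                with self._lock:
                    if not self._queue:
                        return  # no work left: exit cleanly
                    item = self._queue.pop(0)
                print("processed", item)
        finally:
            with self._lock:
                self._ongoing -= 1
                more_work = bool(self._queue)
            if more_work:
                # Work arrived between deciding to exit and decrementing
                # the counter; restart unconditionally, as the diff does.
                self._maybe_start_worker()

pool = FetchPool()
for i in range(10):
    pool.add_work(f"event-{i}")
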
@@ -794,7 +910,7 @@ class EventsWorkerStore(SQLBaseStore):
                 )
 
                 # We only want to resolve deferreds from the main thread
-                def fire():
+                def fire() -> None:
                     for _, d in event_list:
                         d.callback(row_dict)
 
@@ -804,18 +920,16 @@ class EventsWorkerStore(SQLBaseStore):
                 logger.exception("do_fetch")
 
                 # We only want to resolve deferreds from the main thread
-                def fire(evs, exc):
-                    for _, d in evs:
-                        if not d.called:
-                            with PreserveLoggingContext():
-                                d.errback(exc)
+                def fire_errback(exc: Exception) -> None:
+                    for _, d in event_list:
+                        d.errback(exc)
 
                 with PreserveLoggingContext():
-                    self.hs.get_reactor().callFromThread(fire, event_list, e)
+                    self.hs.get_reactor().callFromThread(fire_errback, e)
 
     async def _get_events_from_db(
-        self, event_ids: Iterable[str]
-    ) -> Dict[str, _EventCacheEntry]:
+        self, event_ids: Collection[str]
+    ) -> Dict[str, EventCacheEntry]:
         """Fetch a bunch of events from the database.
 
         May return rejected events.
@@ -831,29 +945,29 @@ class EventsWorkerStore(SQLBaseStore):
             map from event id to result. May return extra events which
             weren't asked for.
         """
-        fetched_events = {}
+        fetched_event_ids: Set[str] = set()
+        fetched_events: Dict[str, _EventRow] = {}
         events_to_fetch = event_ids
 
         while events_to_fetch:
             row_map = await self._enqueue_events(events_to_fetch)
 
             # we need to recursively fetch any redactions of those events
-            redaction_ids = set()
+            redaction_ids: Set[str] = set()
             for event_id in events_to_fetch:
                 row = row_map.get(event_id)
-                fetched_events[event_id] = row
+                fetched_event_ids.add(event_id)
                 if row:
+                    fetched_events[event_id] = row
                     redaction_ids.update(row.redactions)
 
-            events_to_fetch = redaction_ids.difference(fetched_events.keys())
+            events_to_fetch = redaction_ids.difference(fetched_event_ids)
             if events_to_fetch:
                 logger.debug("Also fetching redaction events %s", events_to_fetch)
 
         # build a map from event_id to EventBase
-        event_map = {}
+        event_map: Dict[str, EventBase] = {}
         for event_id, row in fetched_events.items():
-            if not row:
-                continue
             assert row.event_id == event_id
 
             rejected_reason = row.rejected_reason
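
Keeping `fetched_event_ids` (every ID we have asked for) separate from `fetched_events` (rows that actually exist) lets the loop terminate on missing events and drop the `if not row: continue` filtering removed just above. The worklist shape, reduced to a runnable sketch with an invented toy database:

from typing import Dict, Iterable, List, Set

# Toy database: event ID -> IDs of events that redact it.
DB: Dict[str, List[str]] = {
    "$msg": ["$redaction"],
    "$redaction": [],
}

def fetch_rows(event_ids: Iterable[str]) -> Dict[str, List[str]]:
    return {eid: DB[eid] for eid in event_ids if eid in DB}

def get_events_with_redactions(event_ids: Set[str]) -> Dict[str, List[str]]:
    fetched_event_ids: Set[str] = set()        # everything we asked for
    fetched_events: Dict[str, List[str]] = {}  # only rows that exist
    to_fetch = event_ids
    while to_fetch:
        rows = fetch_rows(to_fetch)
        redaction_ids: Set[str] = set()
        for eid in to_fetch:
            fetched_event_ids.add(eid)
            row = rows.get(eid)
            if row is not None:
                fetched_events[eid] = row
                redaction_ids.update(row)
        # Only chase redactions we have not already asked for, so
        # duplicates and cycles cannot loop forever.
        to_fetch = redaction_ids - fetched_event_ids
    return fetched_events

print(get_events_with_redactions({"$msg", "$missing"}))
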
@@ -881,6 +995,7 @@ class EventsWorkerStore(SQLBaseStore):
 
             room_version_id = row.room_version_id
 
+            room_version: Optional[RoomVersion]
             if not room_version_id:
                 # this should only happen for out-of-band membership events which
                 # arrived before #6983 landed. For all other events, we should have
@@ -951,14 +1066,14 @@ class EventsWorkerStore(SQLBaseStore):
 
         # finally, we can decide whether each one needs redacting, and build
         # the cache entries.
-        result_map = {}
+        result_map: Dict[str, EventCacheEntry] = {}
         for event_id, original_ev in event_map.items():
             redactions = fetched_events[event_id].redactions
             redacted_event = self._maybe_redact_event_row(
                 original_ev, redactions, event_map
             )
 
-            cache_entry = _EventCacheEntry(
+            cache_entry = EventCacheEntry(
                 event=original_ev, redacted_event=redacted_event
             )
 
@@ -967,7 +1082,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         return result_map
 
-    async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]:
+    async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]:
         """Fetches events from the database using the _event_fetch_list. This
         allows batch and bulk fetching of events - it allows us to fetch events
         without having to create a new transaction for each request for events.
@@ -980,23 +1095,12 @@ class EventsWorkerStore(SQLBaseStore):
             that weren't requested.
         """
 
-        events_d = defer.Deferred()
+        events_d: "defer.Deferred[Dict[str, _EventRow]]" = defer.Deferred()
         with self._event_fetch_lock:
             self._event_fetch_list.append((events, events_d))
-
             self._event_fetch_lock.notify()
 
-            if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
-                self._event_fetch_ongoing += 1
-                event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
-                should_start = True
-            else:
-                should_start = False
-
-        if should_start:
-            run_as_background_process(
-                "fetch_events", self.db_pool.runWithConnection, self._do_fetch
-            )
+        self._maybe_start_fetch_thread()
 
         logger.debug("Loading %d events: %s", len(events), events)
         with PreserveLoggingContext():
@@ -1146,7 +1250,7 @@ class EventsWorkerStore(SQLBaseStore):
         # no valid redaction found for this event
         return None
 
-    async def have_events_in_timeline(self, event_ids):
+    async def have_events_in_timeline(self, event_ids: Iterable[str]) -> Set[str]:
         """Given a list of event ids, check if we have already processed and
         stored them as non outliers.
         """
@@ -1175,7 +1279,7 @@ class EventsWorkerStore(SQLBaseStore):
             event_ids: events we are looking for
 
         Returns:
-            set[str]: The events we have already seen.
+            The set of events we have already seen.
         """
         res = await self._have_seen_events_dict(
             (room_id, event_id) for event_id in event_ids
@@ -1198,7 +1302,9 @@ class EventsWorkerStore(SQLBaseStore):
         }
         results = {x: True for x in cache_results}
 
-        def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]):
+        def have_seen_events_txn(
+            txn: LoggingTransaction, chunk: Tuple[Tuple[str, str], ...]
+        ) -> None:
             # we deliberately do *not* query the database for room_id, to make the
             # query an index-only lookup on `events_event_id_key`.
             #
@@ -1224,12 +1330,14 @@ class EventsWorkerStore(SQLBaseStore):
         return results
 
     @cached(max_entries=100000, tree=True)
-    async def have_seen_event(self, room_id: str, event_id: str):
+    async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn:
         # this only exists for the benefit of the @cachedList descriptor on
         # _have_seen_events_dict
         raise NotImplementedError()
 
-    def _get_current_state_event_counts_txn(self, txn, room_id):
+    def _get_current_state_event_counts_txn(
+        self, txn: LoggingTransaction, room_id: str
+    ) -> int:
         """
         See get_current_state_event_counts.
         """
@@ -1254,7 +1362,7 @@ class EventsWorkerStore(SQLBaseStore):
             room_id,
         )
 
-    async def get_room_complexity(self, room_id):
+    async def get_room_complexity(self, room_id: str) -> Dict[str, float]:
         """
         Get a rough approximation of the complexity of the room. This is used by
         remote servers to decide whether they wish to join the room or not.
@@ -1262,10 +1370,10 @@ class EventsWorkerStore(SQLBaseStore):
         more resources.
 
         Args:
-            room_id (str)
+            room_id: The room ID to query.
 
         Returns:
-            dict[str:int] of complexity version to complexity.
+            dict[str, float] of complexity version to complexity.
         """
         state_events = await self.get_current_state_event_counts(room_id)
 
@@ -1275,13 +1383,13 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {"v1": complexity_v1}
 
-    def get_current_events_token(self):
+    def get_current_events_token(self) -> int:
         """The current maximum token that events have reached"""
         return self._stream_id_gen.get_current_token()
 
     async def get_all_new_forward_event_rows(
         self, instance_name: str, last_id: int, current_id: int, limit: int
-    ) -> List[Tuple]:
+    ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
         """Returns new events, for the Events replication stream
 
         Args:
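
For context, the v1 complexity figure returned above is derived from the current state event count. A hedged reconstruction, assuming a simple constant scale factor; the real factor is defined elsewhere in the file and does not appear in this diff:

from typing import Dict

def room_complexity_v1(
    state_event_count: int, scale: float = 500.0
) -> Dict[str, float]:
    # Hypothetical: divide the number of current state events by a
    # constant. The actual constant is not shown in this diff.
    return {"v1": round(state_event_count / scale, 2)}

print(room_complexity_v1(1250))  # {'v1': 2.5}
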
@@ -1295,7 +1403,9 @@ class EventsWorkerStore(SQLBaseStore):
             EventsStreamRow.
         """
 
-        def get_all_new_forward_event_rows(txn):
+        def get_all_new_forward_event_rows(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
             sql = (
                 "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
                 " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
@@ -1311,7 +1421,9 @@ class EventsWorkerStore(SQLBaseStore):
                 " LIMIT ?"
             )
             txn.execute(sql, (last_id, current_id, instance_name, limit))
-            return txn.fetchall()
+            return cast(
+                List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+            )
 
         return await self.db_pool.runInteraction(
             "get_all_new_forward_event_rows", get_all_new_forward_event_rows
@@ -1319,7 +1431,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def get_ex_outlier_stream_rows(
         self, instance_name: str, last_id: int, current_id: int
-    ) -> List[Tuple]:
+    ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
         """Returns de-outliered events, for the Events replication stream
 
         Args:
@@ -1332,7 +1444,9 @@ class EventsWorkerStore(SQLBaseStore):
             EventsStreamRow.
         """
 
-        def get_ex_outlier_stream_rows_txn(txn):
+        def get_ex_outlier_stream_rows_txn(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
             sql = (
                 "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
                 " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
@@ -1350,7 +1464,9 @@ class EventsWorkerStore(SQLBaseStore):
             )
 
             txn.execute(sql, (last_id, current_id, instance_name))
-            return txn.fetchall()
+            return cast(
+                List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+            )
 
         return await self.db_pool.runInteraction(
             "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
@@ -1358,7 +1474,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def get_all_new_backfill_event_rows(
         self, instance_name: str, last_id: int, current_id: int, limit: int
-    ) -> Tuple[List[Tuple[int, list]], int, bool]:
+    ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
         """Get updates for backfill replication stream, including all new
         backfilled events and events that have gone from being outliers to not.
 
@@ -1386,7 +1502,9 @@ class EventsWorkerStore(SQLBaseStore):
         if last_id == current_id:
             return [], current_id, False
 
-        def get_all_new_backfill_event_rows(txn):
+        def get_all_new_backfill_event_rows(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
             sql = (
                 "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
                 " state_key, redacts, relates_to_id"
@@ -1400,7 +1518,15 @@ class EventsWorkerStore(SQLBaseStore):
                 " LIMIT ?"
             )
             txn.execute(sql, (-last_id, -current_id, instance_name, limit))
-            new_event_updates = [(row[0], row[1:]) for row in txn]
+            new_event_updates: List[
+                Tuple[int, Tuple[str, str, str, str, str, str]]
+            ] = []
+            row: Tuple[int, str, str, str, str, str, str]
+            # Type safety: iterating over `txn` yields `Tuple`, i.e.
+            # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
+            # variadic tuple to a fixed length tuple and flags it up as an error.
+            for row in txn:  # type: ignore[assignment]
+                new_event_updates.append((row[0], row[1:]))
 
             limited = False
             if len(new_event_updates) == limit:
@@ -1423,7 +1549,11 @@ class EventsWorkerStore(SQLBaseStore):
                 " ORDER BY event_stream_ordering DESC"
             )
             txn.execute(sql, (-last_id, -upper_bound, instance_name))
-            new_event_updates.extend((row[0], row[1:]) for row in txn)
+            # Type safety: iterating over `txn` yields `Tuple`, i.e.
+            # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
+            # variadic tuple to a fixed length tuple and flags it up as an error.
+            for row in txn:  # type: ignore[assignment]
+                new_event_updates.append((row[0], row[1:]))
 
             if len(new_event_updates) >= limit:
                 upper_bound = new_event_updates[-1][0]
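
The `type: ignore[assignment]` comments work around cursors being typed as yielding variadic `Tuple[Any, ...]`, which mypy refuses to assign to a fixed-length tuple. A standalone illustration of the two escape hatches this hunk uses:

from typing import Any, List, Tuple, cast

# Pretend these rows came from iterating a database cursor, which is
# typed as yielding variadic Tuple[Any, ...].
driver_rows: List[Tuple[Any, ...]] = [(1, "a", "b"), (2, "c", "d")]

typed_rows: List[Tuple[int, str, str]] = []
row: Tuple[int, str, str]
for row in driver_rows:  # type: ignore[assignment]
    typed_rows.append(row)

# Or assert the shape in one place with cast(), as the fetchall() call
# sites in this file now do:
typed_rows_2 = cast(List[Tuple[int, str, str]], driver_rows)
print(typed_rows == typed_rows_2)  # True
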
@@ -1437,7 +1567,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def get_all_updated_current_state_deltas(
         self, instance_name: str, from_token: int, to_token: int, target_row_count: int
-    ) -> Tuple[List[Tuple], int, bool]:
+    ) -> Tuple[List[Tuple[int, str, str, str, str]], int, bool]:
         """Fetch updates from current_state_delta_stream
 
         Args:
@@ -1457,7 +1587,9 @@ class EventsWorkerStore(SQLBaseStore):
                * `limited` is whether there are more updates to fetch.
         """
 
-        def get_all_updated_current_state_deltas_txn(txn):
+        def get_all_updated_current_state_deltas_txn(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[int, str, str, str, str]]:
             sql = """
                 SELECT stream_id, room_id, type, state_key, event_id
                 FROM current_state_delta_stream
@@ -1466,21 +1598,23 @@ class EventsWorkerStore(SQLBaseStore):
                 ORDER BY stream_id ASC LIMIT ?
             """
             txn.execute(sql, (from_token, to_token, instance_name, target_row_count))
-            return txn.fetchall()
+            return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
 
-        def get_deltas_for_stream_id_txn(txn, stream_id):
+        def get_deltas_for_stream_id_txn(
+            txn: LoggingTransaction, stream_id: int
+        ) -> List[Tuple[int, str, str, str, str]]:
             sql = """
                 SELECT stream_id, room_id, type, state_key, event_id
                 FROM current_state_delta_stream
                 WHERE stream_id = ?
             """
             txn.execute(sql, [stream_id])
-            return txn.fetchall()
+            return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
 
         # we need to make sure that, for every stream id in the results, we get *all*
         # the rows with that stream id.
 
-        rows: List[Tuple] = await self.db_pool.runInteraction(
+        rows: List[Tuple[int, str, str, str, str]] = await self.db_pool.runInteraction(
             "get_all_updated_current_state_deltas",
             get_all_updated_current_state_deltas_txn,
         )
@@ -1509,14 +1643,14 @@ class EventsWorkerStore(SQLBaseStore):
 
         return rows, to_token, True
 
-    async def is_event_after(self, event_id1, event_id2):
+    async def is_event_after(self, event_id1: str, event_id2: str) -> bool:
         """Returns True if event_id1 is after event_id2 in the stream"""
         to_1, so_1 = await self.get_event_ordering(event_id1)
         to_2, so_2 = await self.get_event_ordering(event_id2)
         return (to_1, so_1) > (to_2, so_2)
 
     @cached(max_entries=5000)
-    async def get_event_ordering(self, event_id):
+    async def get_event_ordering(self, event_id: str) -> Tuple[int, int]:
         res = await self.db_pool.simple_select_one(
             table="events",
             retcols=["topological_ordering", "stream_ordering"],
@@ -1539,7 +1673,9 @@ class EventsWorkerStore(SQLBaseStore):
             None otherwise.
         """
 
-        def get_next_event_to_expire_txn(txn):
+        def get_next_event_to_expire_txn(
+            txn: LoggingTransaction,
+        ) -> Optional[Tuple[str, int]]:
             txn.execute(
                 """
                 SELECT event_id, expiry_ts FROM event_expiry
@@ -1547,7 +1683,7 @@ class EventsWorkerStore(SQLBaseStore):
                 """
             )
 
-            return txn.fetchone()
+            return cast(Optional[Tuple[str, int]], txn.fetchone())
 
         return await self.db_pool.runInteraction(
             desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
@@ -1611,10 +1747,10 @@ class EventsWorkerStore(SQLBaseStore):
         return mapping
 
     @wrap_as_background_process("_cleanup_old_transaction_ids")
-    async def _cleanup_old_transaction_ids(self):
+    async def _cleanup_old_transaction_ids(self) -> None:
         """Cleans out transaction id mappings older than 24hrs."""
 
-        def _cleanup_old_transaction_ids_txn(txn):
+        def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None:
             sql = """
                 DELETE FROM event_txn_id
                 WHERE inserted_ts < ?
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index fa782023d4..3b63267395 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -28,7 +28,10 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.util.id_generators import (
+    AbstractStreamIdTracker,
+    StreamIdGenerator,
+)
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -82,9 +85,9 @@ class PushRulesWorkerStore(
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.worker_app is None:
-            self._push_rules_stream_id_gen: Union[
-                StreamIdGenerator, SlavedIdTracker
-            ] = StreamIdGenerator(db_conn, "push_rules_stream", "stream_id")
+            self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+                db_conn, "push_rules_stream", "stream_id"
+            )
         else:
             self._push_rules_stream_id_gen = SlavedIdTracker(
                 db_conn, "push_rules_stream", "stream_id"
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 0e8c168667..e1ddf06916 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -106,6 +106,15 @@ class RefreshTokenLookupResult:
     has_next_access_token_been_used: bool
     """True if the next access token was already used at least once."""
 
+    expiry_ts: Optional[int]
+    """The time at which the refresh token expires and can not be used.
+    If None, the refresh token doesn't expire."""
+
+    ultimate_session_expiry_ts: Optional[int]
+    """The time at which the session comes to an end and can no longer be
+    refreshed.
+    If None, the session can be refreshed indefinitely."""
+
 
 class RegistrationWorkerStore(CacheInvalidationWorkerStore):
     def __init__(
@@ -1626,8 +1635,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                     rt.user_id,
                     rt.device_id,
                     rt.next_token_id,
-                    (nrt.next_token_id IS NOT NULL) has_next_refresh_token_been_refreshed,
-                    at.used has_next_access_token_been_used
+                    (nrt.next_token_id IS NOT NULL) AS has_next_refresh_token_been_refreshed,
+                    at.used AS has_next_access_token_been_used,
+                    rt.expiry_ts,
+                    rt.ultimate_session_expiry_ts
                 FROM refresh_tokens rt
                 LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id
                 LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id
@@ -1648,6 +1659,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 has_next_refresh_token_been_refreshed=row[4],
                 # This column is nullable, ensure it's a boolean
                 has_next_access_token_been_used=(row[5] or False),
+                expiry_ts=row[6],
+                ultimate_session_expiry_ts=row[7],
             )
 
         return await self.db_pool.runInteraction(
@@ -1915,6 +1928,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
         user_id: str,
         token: str,
         device_id: Optional[str],
+        expiry_ts: Optional[int],
+        ultimate_session_expiry_ts: Optional[int],
     ) -> int:
         """Adds a refresh token for the given user.
 
@@ -1922,6 +1937,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
             user_id: The user ID.
             token: The new refresh token to add.
             device_id: ID of the device to associate with the refresh token.
+            expiry_ts (milliseconds since the epoch): Time after which the
+                refresh token cannot be used.
+                If None, the refresh token only becomes invalid once used.
+            ultimate_session_expiry_ts (milliseconds since the epoch):
+                Time at which the session will end and cannot be extended any
+                further.
+                If None, the session can be refreshed indefinitely.
         Raises:
             StoreError if there was a problem adding this.
         Returns:
@@ -1937,6 +1959,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
                 "device_id": device_id,
                 "token": token,
                 "next_token_id": None,
+                "expiry_ts": expiry_ts,
+                "ultimate_session_expiry_ts": ultimate_session_expiry_ts,
             },
             desc="add_refresh_token_to_user",
         )
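
Taken together, `expiry_ts` bounds the individual refresh token while `ultimate_session_expiry_ts` bounds the whole session however often it is refreshed. A sketch of how an issuing code path might derive `expiry_ts` from both limits; this helper is hypothetical, not the handler's actual code:

from typing import Optional

def compute_refresh_token_expiry(
    now_ms: int,
    refresh_token_lifetime_ms: Optional[int],
    ultimate_session_expiry_ts: Optional[int],
) -> Optional[int]:
    # The token expires after its configured lifetime, but never later
    # than the session's ultimate expiry. None means "no limit".
    expiry_ts: Optional[int] = None
    if refresh_token_lifetime_ms is not None:
        expiry_ts = now_ms + refresh_token_lifetime_ms
    if ultimate_session_expiry_ts is not None:
        if expiry_ts is None or ultimate_session_expiry_ts < expiry_ts:
            expiry_ts = ultimate_session_expiry_ts
    return expiry_ts

# A one-day token in a session that ends in six hours expires when the
# session does:
print(compute_refresh_token_expiry(0, 24 * 3600 * 1000, 6 * 3600 * 1000))
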
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 402f134d89..428d66a617 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -583,7 +583,8 @@ class EventsPersistenceStorage:
                 current_state_for_room=current_state_for_room,
                 state_delta_for_room=state_delta_for_room,
                 new_forward_extremeties=new_forward_extremeties,
-                backfilled=backfilled,
+                use_negative_stream_ordering=backfilled,
+                inhibit_local_membership_updates=backfilled,
             )
 
             await self._handle_potentially_left_users(potentially_left_users)
diff --git a/synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql b/synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql
new file mode 100644
index 0000000000..bdc491c817
--- /dev/null
+++ b/synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql
@@ -0,0 +1,28 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+ALTER TABLE refresh_tokens
+  -- We add an expiry_ts column (in milliseconds since the Epoch) to refresh tokens.
+  -- They may not be used after they have expired.
+  -- If null, then the refresh token's lifetime is unlimited.
+  ADD COLUMN expiry_ts BIGINT DEFAULT NULL;
+
+ALTER TABLE refresh_tokens
+  -- We also add an ultimate session expiry time (in milliseconds since the Epoch).
+  -- No matter how much the access and refresh tokens are refreshed, they cannot
+  -- be extended past this time.
+  -- If null, then the session length is unlimited.
+  ADD COLUMN ultimate_session_expiry_ts BIGINT DEFAULT NULL;
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index ac56bc9a05..4ff3013908 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -89,31 +89,77 @@ def _load_current_id(
     return (max if step > 0 else min)(current_id, step)
 
 
-class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
-    @abc.abstractmethod
-    def get_next(self) -> AsyncContextManager[int]:
-        raise NotImplementedError()
+class AbstractStreamIdTracker(metaclass=abc.ABCMeta):
+    """Tracks the "current" stream ID of a stream that may have multiple writers.
+
+    Stream IDs are monotonically increasing or decreasing integers representing write
+    transactions. The "current" stream ID is the stream ID such that all transactions
+    with equal or smaller stream IDs have completed. Since transactions may complete out
+    of order, this is not the same as the stream ID of the last completed transaction.
+
+    Completed transactions include both committed transactions and transactions that
+    have been rolled back.
+    """
 
     @abc.abstractmethod
-    def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
+    def advance(self, instance_name: str, new_id: int) -> None:
+        """Advance the position of the named writer to the given ID, if greater
+        than existing entry.
+        """
         raise NotImplementedError()
 
     @abc.abstractmethod
     def get_current_token(self) -> int:
+        """Returns the maximum stream id such that all stream ids less than or
+        equal to it have been successfully persisted.
+
+        Returns:
+            The maximum stream id.
+        """
         raise NotImplementedError()
 
     @abc.abstractmethod
     def get_current_token_for_writer(self, instance_name: str) -> int:
+        """Returns the position of the given writer.
+
+        For streams with single writers this is equivalent to `get_current_token`.
+        """
+        raise NotImplementedError()
+
+
+class AbstractStreamIdGenerator(AbstractStreamIdTracker):
+    """Generates stream IDs for a stream that may have multiple writers.
+
+    Each stream ID represents a write transaction, whose completion is tracked
+    so that the "current" stream ID of the stream can be determined.
+
+    See `AbstractStreamIdTracker` for more details.
+    """
+
+    @abc.abstractmethod
+    def get_next(self) -> AsyncContextManager[int]:
+        """
+        Usage:
+            async with stream_id_gen.get_next() as stream_id:
+                # ... persist event ...
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
+        """
+        Usage:
+            async with stream_id_gen.get_next(n) as stream_ids:
+                # ... persist events ...
+        """
         raise NotImplementedError()
 
 
 class StreamIdGenerator(AbstractStreamIdGenerator):
-    """Used to generate new stream ids when persisting events while keeping
-    track of which transactions have been completed.
+    """Generates and tracks stream IDs for a stream with a single writer.
 
-    This allows us to get the "current" stream id, i.e. the stream id such that
-    all ids less than or equal to it have completed. This handles the fact that
-    persistence of events can complete out of order.
+    This class must only be used when the current Synapse process is the sole
+    writer for a stream.
 
     Args:
         db_conn(connection):  A database connection to use to fetch the
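
The docstring's distinction (the "current" stream ID is the largest ID below which everything has completed, not the ID of the last completed transaction) is easiest to see with a toy single-writer tracker, invented for the example:

import threading
from typing import Set

class TinyStreamIdTracker:
    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._next = 0
        self._unfinished: Set[int] = set()

    def begin(self) -> int:
        with self._lock:
            self._next += 1
            self._unfinished.add(self._next)
            return self._next

    def finish(self, stream_id: int) -> None:
        with self._lock:
            self._unfinished.discard(stream_id)

    def get_current_token(self) -> int:
        with self._lock:
            if self._unfinished:
                # Everything strictly below the smallest unfinished ID has
                # completed, even if later IDs finished first.
                return min(self._unfinished) - 1
            return self._next

t = TinyStreamIdTracker()
first, second, third = t.begin(), t.begin(), t.begin()  # IDs 1, 2, 3
t.finish(third)                 # ID 3 completes out of order
t.finish(first)
print(t.get_current_token())    # 1: ID 2 is still in flight
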
@@ -157,12 +203,12 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
         # The key and values are the same, but we never look at the values.
         self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
 
+    def advance(self, instance_name: str, new_id: int) -> None:
+        # `StreamIdGenerator` should only be used when there is a single writer,
+        # so replication should never happen.
+        raise Exception("Replication is not supported by StreamIdGenerator")
+
     def get_next(self) -> AsyncContextManager[int]:
-        """
-        Usage:
-            async with stream_id_gen.get_next() as stream_id:
-                # ... persist event ...
-        """
         with self._lock:
             self._current += self._step
             next_id = self._current
@@ -180,11 +226,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
         return _AsyncCtxManagerWrapper(manager())
 
     def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
-        """
-        Usage:
-            async with stream_id_gen.get_next(n) as stream_ids:
-                # ... persist events ...
-        """
         with self._lock:
             next_ids = range(
                 self._current + self._step,
@@ -208,12 +249,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
         return _AsyncCtxManagerWrapper(manager())
 
     def get_current_token(self) -> int:
-        """Returns the maximum stream id such that all stream ids less than or
-        equal to it have been successfully persisted.
-
-        Returns:
-            The maximum stream id.
-        """
         with self._lock:
             if self._unfinished_ids:
                 return next(iter(self._unfinished_ids)) - self._step
@@ -221,16 +256,11 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
             return self._current
 
     def get_current_token_for_writer(self, instance_name: str) -> int:
-        """Returns the position of the given writer.
-
-        For streams with single writers this is equivalent to
-        `get_current_token`.
-        """
         return self.get_current_token()
 
 
 class MultiWriterIdGenerator(AbstractStreamIdGenerator):
-    """An ID generator that tracks a stream that can have multiple writers.
+    """Generates and tracks stream IDs for a stream with multiple writers.
 
     Uses a Postgres sequence to coordinate ID assignment, but positions of other
     writers will only get updated when `advance` is called (by replication).
@@ -475,12 +505,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
         return stream_ids
 
     def get_next(self) -> AsyncContextManager[int]:
-        """
-        Usage:
-            async with stream_id_gen.get_next() as stream_id:
-                # ... persist event ...
-        """
-
         # If we have a list of instances that are allowed to write to this
         # stream, make sure we're in it.
         if self._writers and self._instance_name not in self._writers:
@@ -492,12 +516,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
         return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))
 
     def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
-        """
-        Usage:
-            async with stream_id_gen.get_next_mult(5) as stream_ids:
-                # ... persist events ...
-        """
-
         # If we have a list of instances that are allowed to write to this
         # stream, make sure we're in it.
         if self._writers and self._instance_name not in self._writers:
@@ -597,15 +615,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
             self._add_persisted_position(next_id)
 
     def get_current_token(self) -> int:
-        """Returns the maximum stream id such that all stream ids less than or
-        equal to it have been successfully persisted.
-        """
-
         return self.get_persisted_upto_position()
 
     def get_current_token_for_writer(self, instance_name: str) -> int:
-        """Returns the position of the given writer."""
-
         # If we don't have an entry for the given instance name, we assume it's a
         # new writer.
         #
@@ -631,10 +643,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
             }
 
     def advance(self, instance_name: str, new_id: int) -> None:
-        """Advance the position of the named writer to the given ID, if greater
-        than existing entry.
-        """
-
         new_id *= self._return_factor
 
         with self._lock: