Diffstat (limited to 'synapse')
-rw-r--r--  synapse/__init__.py | 2
-rw-r--r--  synapse/api/constants.py | 198
-rw-r--r--  synapse/api/urls.py | 3
-rw-r--r--  synapse/app/_base.py | 5
-rw-r--r--  synapse/app/generic_worker.py | 14
-rw-r--r--  synapse/app/homeserver.py | 18
-rw-r--r--  synapse/appservice/__init__.py | 3
-rw-r--r--  synapse/appservice/api.py | 23
-rw-r--r--  synapse/config/__main__.py | 3
-rw-r--r--  synapse/config/_base.py | 157
-rw-r--r--  synapse/config/_base.pyi | 87
-rw-r--r--  synapse/config/appservice.py | 23
-rw-r--r--  synapse/config/cache.py | 30
-rw-r--r--  synapse/config/cas.py | 5
-rw-r--r--  synapse/config/database.py | 13
-rw-r--r--  synapse/config/emailconfig.py | 33
-rw-r--r--  synapse/config/experimental.py | 3
-rw-r--r--  synapse/config/jwt.py | 9
-rw-r--r--  synapse/config/key.py | 3
-rw-r--r--  synapse/config/logger.py | 26
-rw-r--r--  synapse/config/oidc.py | 58
-rw-r--r--  synapse/config/registration.py | 133
-rw-r--r--  synapse/config/repository.py | 9
-rw-r--r--  synapse/config/room_directory.py | 53
-rw-r--r--  synapse/config/saml2.py | 21
-rw-r--r--  synapse/config/server.py | 24
-rw-r--r--  synapse/config/sso.py | 12
-rw-r--r--  synapse/config/tls.py | 59
-rw-r--r--  synapse/config/user_directory.py | 4
-rw-r--r--  synapse/config/workers.py | 4
-rw-r--r--  synapse/crypto/keyring.py | 99
-rw-r--r--  synapse/events/snapshot.py | 5
-rw-r--r--  synapse/events/utils.py | 170
-rw-r--r--  synapse/federation/federation_client.py | 112
-rw-r--r--  synapse/federation/federation_server.py | 61
-rw-r--r--  synapse/federation/persistence.py | 4
-rw-r--r--  synapse/federation/send_queue.py | 25
-rw-r--r--  synapse/federation/sender/per_destination_queue.py | 13
-rw-r--r--  synapse/federation/transport/client.py | 91
-rw-r--r--  synapse/federation/transport/server/__init__.py | 14
-rw-r--r--  synapse/federation/transport/server/_base.py | 48
-rw-r--r--  synapse/federation/transport/server/federation.py | 47
-rw-r--r--  synapse/groups/attestations.py | 4
-rw-r--r--  synapse/handlers/auth.py | 158
-rw-r--r--  synapse/handlers/device.py | 8
-rw-r--r--  synapse/handlers/events.py | 3
-rw-r--r--  synapse/handlers/federation.py | 61
-rw-r--r--  synapse/handlers/identity.py | 18
-rw-r--r--  synapse/handlers/initial_sync.py | 30
-rw-r--r--  synapse/handlers/message.py | 62
-rw-r--r--  synapse/handlers/oidc.py | 58
-rw-r--r--  synapse/handlers/register.py | 84
-rw-r--r--  synapse/handlers/room.py | 149
-rw-r--r--  synapse/handlers/room_batch.py | 2
-rw-r--r--  synapse/handlers/room_member.py | 15
-rw-r--r--  synapse/handlers/room_summary.py | 23
-rw-r--r--  synapse/handlers/sso.py | 4
-rw-r--r--  synapse/handlers/sync.py | 156
-rw-r--r--  synapse/handlers/typing.py | 2
-rw-r--r--  synapse/http/servlet.py | 29
-rw-r--r--  synapse/metrics/__init__.py | 101
-rw-r--r--  synapse/metrics/_exposition.py | 34
-rw-r--r--  synapse/metrics/background_process_metrics.py | 78
-rw-r--r--  synapse/metrics/jemalloc.py | 10
-rw-r--r--  synapse/module_api/__init__.py | 298
-rw-r--r--  synapse/push/emailpusher.py | 10
-rw-r--r--  synapse/push/httppusher.py | 3
-rw-r--r--  synapse/push/mailer.py | 72
-rw-r--r--  synapse/push/push_rule_evaluator.py | 7
-rw-r--r--  synapse/push/push_types.py | 136
-rw-r--r--  synapse/python_dependencies.py | 3
-rw-r--r--  synapse/replication/http/login.py | 8
-rw-r--r--  synapse/replication/slave/storage/_slaved_id_tracker.py | 22
-rw-r--r--  synapse/replication/slave/storage/push_rule.py | 4
-rw-r--r--  synapse/replication/tcp/streams/events.py | 6
-rw-r--r--  synapse/rest/admin/__init__.py | 29
-rw-r--r--  synapse/rest/admin/_base.py | 3
-rw-r--r--  synapse/rest/admin/background_updates.py | 123
-rw-r--r--  synapse/rest/admin/devices.py | 21
-rw-r--r--  synapse/rest/admin/event_reports.py | 21
-rw-r--r--  synapse/rest/admin/federation.py | 135
-rw-r--r--  synapse/rest/admin/groups.py | 5
-rw-r--r--  synapse/rest/admin/media.py | 53
-rw-r--r--  synapse/rest/admin/registration_tokens.py | 51
-rw-r--r--  synapse/rest/admin/rooms.py | 144
-rw-r--r--  synapse/rest/admin/server_notice_servlet.py | 11
-rw-r--r--  synapse/rest/admin/statistics.py | 21
-rw-r--r--  synapse/rest/admin/users.py | 175
-rw-r--r--  synapse/rest/client/_base.py | 4
-rw-r--r--  synapse/rest/client/keys.py | 2
-rw-r--r--  synapse/rest/client/login.py | 97
-rw-r--r--  synapse/rest/client/register.py | 22
-rw-r--r--  synapse/rest/client/relations.py | 14
-rw-r--r--  synapse/rest/client/room.py | 67
-rw-r--r--  synapse/rest/client/sync.py | 6
-rw-r--r--  synapse/rest/media/v1/_base.py | 18
-rw-r--r--  synapse/rest/media/v1/filepath.py | 270
-rw-r--r--  synapse/server.py | 5
-rw-r--r--  synapse/state/__init__.py | 2
-rw-r--r--  synapse/state/v1.py | 3
-rw-r--r--  synapse/storage/_base.py | 4
-rw-r--r--  synapse/storage/background_updates.py | 196
-rw-r--r--  synapse/storage/database.py | 6
-rw-r--r--  synapse/storage/databases/main/appservice.py | 6
-rw-r--r--  synapse/storage/databases/main/deviceinbox.py | 211
-rw-r--r--  synapse/storage/databases/main/devices.py | 50
-rw-r--r--  synapse/storage/databases/main/end_to_end_keys.py | 51
-rw-r--r--  synapse/storage/databases/main/event_federation.py | 4
-rw-r--r--  synapse/storage/databases/main/event_push_actions.py | 19
-rw-r--r--  synapse/storage/databases/main/events.py | 136
-rw-r--r--  synapse/storage/databases/main/events_bg_updates.py | 88
-rw-r--r--  synapse/storage/databases/main/events_worker.py | 581
-rw-r--r--  synapse/storage/databases/main/purge_events.py | 2
-rw-r--r--  synapse/storage/databases/main/push_rule.py | 11
-rw-r--r--  synapse/storage/databases/main/registration.py | 80
-rw-r--r--  synapse/storage/databases/main/relations.py | 67
-rw-r--r--  synapse/storage/databases/main/room.py | 32
-rw-r--r--  synapse/storage/databases/main/roommember.py | 4
-rw-r--r--  synapse/storage/databases/main/stream.py | 15
-rw-r--r--  synapse/storage/databases/main/transactions.py | 70
-rw-r--r--  synapse/storage/persist_events.py | 3
-rw-r--r--  synapse/storage/prepare_database.py | 40
-rw-r--r--  synapse/storage/schema/__init__.py | 6
-rw-r--r--  synapse/storage/schema/main/delta/65/06remove_deleted_devices_from_device_inbox.sql | 34
-rw-r--r--  synapse/storage/schema/main/delta/65/07_arbitrary_relations.sql (renamed from synapse/storage/schema/main/delta/65/02_thread_relations.sql) | 2
-rw-r--r--  synapse/storage/schema/main/delta/65/08_device_inbox_background_updates.sql (renamed from synapse/storage/schema/main/delta/65/05remove_deleted_devices_from_device_inbox.sql) | 8
-rw-r--r--  synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql | 28
-rw-r--r--  synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql | 27
-rw-r--r--  synapse/storage/util/id_generators.py | 116
-rw-r--r--  synapse/util/__init__.py | 59
-rw-r--r--  synapse/util/async_helpers.py | 32
-rw-r--r--  synapse/util/caches/__init__.py | 32
-rw-r--r--  synapse/util/caches/deferred_cache.py | 11
-rw-r--r--  synapse/util/caches/descriptors.py | 67
-rw-r--r--  synapse/util/caches/expiringcache.py | 12
-rw-r--r--  synapse/util/caches/lrucache.py | 42
-rw-r--r--  synapse/util/distributor.py | 11
-rw-r--r--  synapse/util/gai_resolver.py | 75
-rw-r--r--  synapse/util/linked_list.py | 4
-rw-r--r--  synapse/util/metrics.py | 12
-rw-r--r--  synapse/util/stringutils.py | 21
-rw-r--r--  synapse/util/versionstring.py | 82
142 files changed, 5000 insertions, 2041 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 06b179a7e8..3cd1ce6070 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -47,7 +47,7 @@ try:
 except ImportError:
     pass
 
-__version__ = "1.47.0rc2"
+__version__ = "1.48.0"
 
 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index a33ac34161..f7d29b4319 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -17,6 +17,8 @@
 
 """Contains constants from the specification."""
 
+from typing_extensions import Final
+
 # the max size of a (canonical-json-encoded) event
 MAX_PDU_SIZE = 65536
 
@@ -39,125 +41,125 @@ class Membership:
 
     """Represents the membership states of a user in a room."""
 
-    INVITE = "invite"
-    JOIN = "join"
-    KNOCK = "knock"
-    LEAVE = "leave"
-    BAN = "ban"
-    LIST = (INVITE, JOIN, KNOCK, LEAVE, BAN)
+    INVITE: Final = "invite"
+    JOIN: Final = "join"
+    KNOCK: Final = "knock"
+    LEAVE: Final = "leave"
+    BAN: Final = "ban"
+    LIST: Final = (INVITE, JOIN, KNOCK, LEAVE, BAN)
 
 
 class PresenceState:
     """Represents the presence state of a user."""
 
-    OFFLINE = "offline"
-    UNAVAILABLE = "unavailable"
-    ONLINE = "online"
-    BUSY = "org.matrix.msc3026.busy"
+    OFFLINE: Final = "offline"
+    UNAVAILABLE: Final = "unavailable"
+    ONLINE: Final = "online"
+    BUSY: Final = "org.matrix.msc3026.busy"
 
 
 class JoinRules:
-    PUBLIC = "public"
-    KNOCK = "knock"
-    INVITE = "invite"
-    PRIVATE = "private"
+    PUBLIC: Final = "public"
+    KNOCK: Final = "knock"
+    INVITE: Final = "invite"
+    PRIVATE: Final = "private"
     # As defined for MSC3083.
-    RESTRICTED = "restricted"
+    RESTRICTED: Final = "restricted"
 
 
 class RestrictedJoinRuleTypes:
     """Understood types for the allow rules in restricted join rules."""
 
-    ROOM_MEMBERSHIP = "m.room_membership"
+    ROOM_MEMBERSHIP: Final = "m.room_membership"
 
 
 class LoginType:
-    PASSWORD = "m.login.password"
-    EMAIL_IDENTITY = "m.login.email.identity"
-    MSISDN = "m.login.msisdn"
-    RECAPTCHA = "m.login.recaptcha"
-    TERMS = "m.login.terms"
-    SSO = "m.login.sso"
-    DUMMY = "m.login.dummy"
-    REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token"
+    PASSWORD: Final = "m.login.password"
+    EMAIL_IDENTITY: Final = "m.login.email.identity"
+    MSISDN: Final = "m.login.msisdn"
+    RECAPTCHA: Final = "m.login.recaptcha"
+    TERMS: Final = "m.login.terms"
+    SSO: Final = "m.login.sso"
+    DUMMY: Final = "m.login.dummy"
+    REGISTRATION_TOKEN: Final = "org.matrix.msc3231.login.registration_token"
 
 
 # This is used in the `type` parameter for /register when called by
 # an appservice to register a new user.
-APP_SERVICE_REGISTRATION_TYPE = "m.login.application_service"
+APP_SERVICE_REGISTRATION_TYPE: Final = "m.login.application_service"
 
 
 class EventTypes:
-    Member = "m.room.member"
-    Create = "m.room.create"
-    Tombstone = "m.room.tombstone"
-    JoinRules = "m.room.join_rules"
-    PowerLevels = "m.room.power_levels"
-    Aliases = "m.room.aliases"
-    Redaction = "m.room.redaction"
-    ThirdPartyInvite = "m.room.third_party_invite"
-    RelatedGroups = "m.room.related_groups"
-
-    RoomHistoryVisibility = "m.room.history_visibility"
-    CanonicalAlias = "m.room.canonical_alias"
-    Encrypted = "m.room.encrypted"
-    RoomAvatar = "m.room.avatar"
-    RoomEncryption = "m.room.encryption"
-    GuestAccess = "m.room.guest_access"
+    Member: Final = "m.room.member"
+    Create: Final = "m.room.create"
+    Tombstone: Final = "m.room.tombstone"
+    JoinRules: Final = "m.room.join_rules"
+    PowerLevels: Final = "m.room.power_levels"
+    Aliases: Final = "m.room.aliases"
+    Redaction: Final = "m.room.redaction"
+    ThirdPartyInvite: Final = "m.room.third_party_invite"
+    RelatedGroups: Final = "m.room.related_groups"
+
+    RoomHistoryVisibility: Final = "m.room.history_visibility"
+    CanonicalAlias: Final = "m.room.canonical_alias"
+    Encrypted: Final = "m.room.encrypted"
+    RoomAvatar: Final = "m.room.avatar"
+    RoomEncryption: Final = "m.room.encryption"
+    GuestAccess: Final = "m.room.guest_access"
 
     # These are used for validation
-    Message = "m.room.message"
-    Topic = "m.room.topic"
-    Name = "m.room.name"
+    Message: Final = "m.room.message"
+    Topic: Final = "m.room.topic"
+    Name: Final = "m.room.name"
 
-    ServerACL = "m.room.server_acl"
-    Pinned = "m.room.pinned_events"
+    ServerACL: Final = "m.room.server_acl"
+    Pinned: Final = "m.room.pinned_events"
 
-    Retention = "m.room.retention"
+    Retention: Final = "m.room.retention"
 
-    Dummy = "org.matrix.dummy_event"
+    Dummy: Final = "org.matrix.dummy_event"
 
-    SpaceChild = "m.space.child"
-    SpaceParent = "m.space.parent"
+    SpaceChild: Final = "m.space.child"
+    SpaceParent: Final = "m.space.parent"
 
-    MSC2716_INSERTION = "org.matrix.msc2716.insertion"
-    MSC2716_BATCH = "org.matrix.msc2716.batch"
-    MSC2716_MARKER = "org.matrix.msc2716.marker"
+    MSC2716_INSERTION: Final = "org.matrix.msc2716.insertion"
+    MSC2716_BATCH: Final = "org.matrix.msc2716.batch"
+    MSC2716_MARKER: Final = "org.matrix.msc2716.marker"
 
 
 class ToDeviceEventTypes:
-    RoomKeyRequest = "m.room_key_request"
+    RoomKeyRequest: Final = "m.room_key_request"
 
 
 class DeviceKeyAlgorithms:
     """Spec'd algorithms for the generation of per-device keys"""
 
-    ED25519 = "ed25519"
-    CURVE25519 = "curve25519"
-    SIGNED_CURVE25519 = "signed_curve25519"
+    ED25519: Final = "ed25519"
+    CURVE25519: Final = "curve25519"
+    SIGNED_CURVE25519: Final = "signed_curve25519"
 
 
 class EduTypes:
-    Presence = "m.presence"
+    Presence: Final = "m.presence"
 
 
 class RejectedReason:
-    AUTH_ERROR = "auth_error"
+    AUTH_ERROR: Final = "auth_error"
 
 
 class RoomCreationPreset:
-    PRIVATE_CHAT = "private_chat"
-    PUBLIC_CHAT = "public_chat"
-    TRUSTED_PRIVATE_CHAT = "trusted_private_chat"
+    PRIVATE_CHAT: Final = "private_chat"
+    PUBLIC_CHAT: Final = "public_chat"
+    TRUSTED_PRIVATE_CHAT: Final = "trusted_private_chat"
 
 
 class ThirdPartyEntityKind:
-    USER = "user"
-    LOCATION = "location"
+    USER: Final = "user"
+    LOCATION: Final = "location"
 
 
-ServerNoticeMsgType = "m.server_notice"
-ServerNoticeLimitReached = "m.server_notice.usage_limit_reached"
+ServerNoticeMsgType: Final = "m.server_notice"
+ServerNoticeLimitReached: Final = "m.server_notice.usage_limit_reached"
 
 
 class UserTypes:
@@ -165,91 +167,91 @@ class UserTypes:
     'admin' and 'guest' users should also be UserTypes. Normal users are type None
     """
 
-    SUPPORT = "support"
-    BOT = "bot"
-    ALL_USER_TYPES = (SUPPORT, BOT)
+    SUPPORT: Final = "support"
+    BOT: Final = "bot"
+    ALL_USER_TYPES: Final = (SUPPORT, BOT)
 
 
 class RelationTypes:
     """The types of relations known to this server."""
 
-    ANNOTATION = "m.annotation"
-    REPLACE = "m.replace"
-    REFERENCE = "m.reference"
-    THREAD = "io.element.thread"
+    ANNOTATION: Final = "m.annotation"
+    REPLACE: Final = "m.replace"
+    REFERENCE: Final = "m.reference"
+    THREAD: Final = "io.element.thread"
 
 
 class LimitBlockingTypes:
     """Reasons that a server may be blocked"""
 
-    MONTHLY_ACTIVE_USER = "monthly_active_user"
-    HS_DISABLED = "hs_disabled"
+    MONTHLY_ACTIVE_USER: Final = "monthly_active_user"
+    HS_DISABLED: Final = "hs_disabled"
 
 
 class EventContentFields:
     """Fields found in events' content, regardless of type."""
 
     # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326
-    LABELS = "org.matrix.labels"
+    LABELS: Final = "org.matrix.labels"
 
     # Timestamp to delete the event after
     # cf https://github.com/matrix-org/matrix-doc/pull/2228
-    SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after"
+    SELF_DESTRUCT_AFTER: Final = "org.matrix.self_destruct_after"
 
     # cf https://github.com/matrix-org/matrix-doc/pull/1772
-    ROOM_TYPE = "type"
+    ROOM_TYPE: Final = "type"
 
     # Whether a room can federate.
-    FEDERATE = "m.federate"
+    FEDERATE: Final = "m.federate"
 
     # The creator of the room, as used in `m.room.create` events.
-    ROOM_CREATOR = "creator"
+    ROOM_CREATOR: Final = "creator"
 
     # Used in m.room.guest_access events.
-    GUEST_ACCESS = "guest_access"
+    GUEST_ACCESS: Final = "guest_access"
 
     # Used on normal messages to indicate they were historically imported after the fact
-    MSC2716_HISTORICAL = "org.matrix.msc2716.historical"
+    MSC2716_HISTORICAL: Final = "org.matrix.msc2716.historical"
     # For "insertion" events to indicate what the next batch ID should be in
     # order to connect to it
-    MSC2716_NEXT_BATCH_ID = "org.matrix.msc2716.next_batch_id"
+    MSC2716_NEXT_BATCH_ID: Final = "org.matrix.msc2716.next_batch_id"
     # Used on "batch" events to indicate which insertion event it connects to
-    MSC2716_BATCH_ID = "org.matrix.msc2716.batch_id"
+    MSC2716_BATCH_ID: Final = "org.matrix.msc2716.batch_id"
     # For "marker" events
-    MSC2716_MARKER_INSERTION = "org.matrix.msc2716.marker.insertion"
+    MSC2716_MARKER_INSERTION: Final = "org.matrix.msc2716.marker.insertion"
 
     # The authorising user for joining a restricted room.
-    AUTHORISING_USER = "join_authorised_via_users_server"
+    AUTHORISING_USER: Final = "join_authorised_via_users_server"
 
 
 class RoomTypes:
     """Understood values of the room_type field of m.room.create events."""
 
-    SPACE = "m.space"
+    SPACE: Final = "m.space"
 
 
 class RoomEncryptionAlgorithms:
-    MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2"
-    DEFAULT = MEGOLM_V1_AES_SHA2
+    MEGOLM_V1_AES_SHA2: Final = "m.megolm.v1.aes-sha2"
+    DEFAULT: Final = MEGOLM_V1_AES_SHA2
 
 
 class AccountDataTypes:
-    DIRECT = "m.direct"
-    IGNORED_USER_LIST = "m.ignored_user_list"
+    DIRECT: Final = "m.direct"
+    IGNORED_USER_LIST: Final = "m.ignored_user_list"
 
 
 class HistoryVisibility:
-    INVITED = "invited"
-    JOINED = "joined"
-    SHARED = "shared"
-    WORLD_READABLE = "world_readable"
+    INVITED: Final = "invited"
+    JOINED: Final = "joined"
+    SHARED: Final = "shared"
+    WORLD_READABLE: Final = "world_readable"
 
 
 class GuestAccess:
-    CAN_JOIN = "can_join"
+    CAN_JOIN: Final = "can_join"
     # anything that is not "can_join" is considered "forbidden", but for completeness:
-    FORBIDDEN = "forbidden"
+    FORBIDDEN: Final = "forbidden"
 
 
 class ReadReceiptEventFields:
-    MSC2285_HIDDEN = "org.matrix.msc2285.hidden"
+    MSC2285_HIDDEN: Final = "org.matrix.msc2285.hidden"
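The constants.py hunk above adds a Final annotation, imported from typing_extensions, to every constant. As a minimal illustrative sketch (not taken from this commit), the annotation changes nothing at runtime, but a type checker such as mypy will reject later reassignment and treat each attribute as having its literal value's type:

from typing_extensions import Final

class Membership:
    JOIN: Final = "join"
    LEAVE: Final = "leave"

# Runs fine at runtime, but mypy flags it as assigning to a Final attribute,
# which is exactly the kind of accidental mutation this change guards against.
Membership.JOIN = "joined"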
diff --git a/synapse/api/urls.py b/synapse/api/urls.py
index 4486b3bc7d..f9f9467dc1 100644
--- a/synapse/api/urls.py
+++ b/synapse/api/urls.py
@@ -30,7 +30,8 @@ FEDERATION_UNSTABLE_PREFIX = FEDERATION_PREFIX + "/unstable"
 STATIC_PREFIX = "/_matrix/static"
 WEB_CLIENT_PREFIX = "/_matrix/client"
 SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
-MEDIA_PREFIX = "/_matrix/media/r0"
+MEDIA_R0_PREFIX = "/_matrix/media/r0"
+MEDIA_V3_PREFIX = "/_matrix/media/v3"
 LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"
 
 
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 573bb487b2..5fc59c1be1 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -32,6 +32,7 @@ from typing import (
     Iterable,
     List,
     NoReturn,
+    Optional,
     Tuple,
     cast,
 )
@@ -129,7 +130,7 @@ def start_worker_reactor(
 def start_reactor(
     appname: str,
     soft_file_limit: int,
-    gc_thresholds: Tuple[int, int, int],
+    gc_thresholds: Optional[Tuple[int, int, int]],
     pid_file: str,
     daemonize: bool,
     print_pidfile: bool,
@@ -402,7 +403,7 @@ async def start(hs: "HomeServer") -> None:
     if hasattr(signal, "SIGHUP"):
 
         @wrap_as_background_process("sighup")
-        def handle_sighup(*args: Any, **kwargs: Any) -> None:
+        async def handle_sighup(*args: Any, **kwargs: Any) -> None:
             # Tell systemd our state, if we're using it. This will silently fail if
             # we're not using systemd.
             sdnotify(b"RELOADING=1")
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 46f0feff70..e256de2003 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -26,7 +26,8 @@ from synapse.api.urls import (
     CLIENT_API_PREFIX,
     FEDERATION_PREFIX,
     LEGACY_MEDIA_PREFIX,
-    MEDIA_PREFIX,
+    MEDIA_R0_PREFIX,
+    MEDIA_V3_PREFIX,
     SERVER_KEY_V2_PREFIX,
 )
 from synapse.app import _base
@@ -112,6 +113,7 @@ from synapse.storage.databases.main.monthly_active_users import (
 )
 from synapse.storage.databases.main.presence import PresenceStore
 from synapse.storage.databases.main.room import RoomWorkerStore
+from synapse.storage.databases.main.room_batch import RoomBatchStore
 from synapse.storage.databases.main.search import SearchStore
 from synapse.storage.databases.main.session import SessionStore
 from synapse.storage.databases.main.stats import StatsStore
@@ -239,6 +241,7 @@ class GenericWorkerSlavedStore(
     SlavedEventStore,
     SlavedKeyStore,
     RoomWorkerStore,
+    RoomBatchStore,
     DirectoryStore,
     SlavedApplicationServiceStore,
     SlavedRegistrationStore,
@@ -338,7 +341,8 @@ class GenericWorkerServer(HomeServer):
 
                         resources.update(
                             {
-                                MEDIA_PREFIX: media_repo,
+                                MEDIA_R0_PREFIX: media_repo,
+                                MEDIA_V3_PREFIX: media_repo,
                                 LEGACY_MEDIA_PREFIX: media_repo,
                                 "/_synapse/admin": admin_resource,
                             }
@@ -501,6 +505,10 @@ def start(config_options: List[str]) -> None:
     _base.start_worker_reactor("synapse-generic-worker", config)
 
 
-if __name__ == "__main__":
+def main() -> None:
     with LoggingContext("main"):
         start(sys.argv[1:])
+
+
+if __name__ == "__main__":
+    main()
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 7bb3744f04..dd76e07321 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -29,7 +29,8 @@ from synapse import events
 from synapse.api.urls import (
     FEDERATION_PREFIX,
     LEGACY_MEDIA_PREFIX,
-    MEDIA_PREFIX,
+    MEDIA_R0_PREFIX,
+    MEDIA_V3_PREFIX,
     SERVER_KEY_V2_PREFIX,
     STATIC_PREFIX,
     WEB_CLIENT_PREFIX,
@@ -193,6 +194,8 @@ class SynapseHomeServer(HomeServer):
                 {
                     "/_matrix/client/api/v1": client_resource,
                     "/_matrix/client/r0": client_resource,
+                    "/_matrix/client/v1": client_resource,
+                    "/_matrix/client/v3": client_resource,
                     "/_matrix/client/unstable": client_resource,
                     "/_matrix/client/v2_alpha": client_resource,
                     "/_matrix/client/versions": client_resource,
@@ -244,7 +247,11 @@ class SynapseHomeServer(HomeServer):
             if self.config.server.enable_media_repo:
                 media_repo = self.get_media_repository_resource()
                 resources.update(
-                    {MEDIA_PREFIX: media_repo, LEGACY_MEDIA_PREFIX: media_repo}
+                    {
+                        MEDIA_R0_PREFIX: media_repo,
+                        MEDIA_V3_PREFIX: media_repo,
+                        LEGACY_MEDIA_PREFIX: media_repo,
+                    }
                 )
             elif name == "media":
                 raise ConfigError(
@@ -351,6 +358,13 @@ def setup(config_options: List[str]) -> SynapseHomeServer:
         # generating config files and shouldn't try to continue.
         sys.exit(0)
 
+    if config.worker.worker_app:
+        raise ConfigError(
+            "You have specified `worker_app` in the config but are attempting to start a non-worker "
+            "instance. Please use `python -m synapse.app.generic_worker` instead (or remove the option if this is the main process)."
+        )
+        sys.exit(1)
+
     events.USE_FROZEN_DICTS = config.server.use_frozen_dicts
     synapse.util.caches.TRACK_MEMORY_USAGE = config.caches.track_memory_usage
 
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index 0ca8b2ae40..e33e69eed1 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import logging
 import re
+from enum import Enum
 from typing import TYPE_CHECKING, Iterable, List, Match, Optional
 
 from synapse.api.constants import EventTypes
@@ -27,7 +28,7 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class ApplicationServiceState:
+class ApplicationServiceState(Enum):
     DOWN = "down"
     UP = "up"
 
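Turning ApplicationServiceState into an Enum means callers now handle enum members rather than bare strings, with the underlying string still reachable through .value. A small sketch of the resulting behaviour, using a hypothetical mirror of the class rather than the real one:

from enum import Enum

class ServiceState(Enum):
    DOWN = "down"
    UP = "up"

state = ServiceState.DOWN
assert state is ServiceState.DOWN              # members compare by identity
assert state.value == "down"                   # raw string, e.g. for persistence
assert ServiceState("up") is ServiceState.UP   # parse a stored string back into a member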
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index a54b4e867d..ca58f92339 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -249,13 +249,32 @@ class ApplicationServiceApi(SimpleHttpClient):
                 json_body=body,
                 args={"access_token": service.hs_token},
             )
+            if logger.isEnabledFor(logging.DEBUG):
+                logger.debug(
+                    "push_bulk to %s succeeded! events=%s",
+                    uri,
+                    [event.get("event_id") for event in events],
+                )
             sent_transactions_counter.labels(service.id).inc()
             sent_events_counter.labels(service.id).inc(len(events))
             return True
         except CodeMessageException as e:
-            logger.warning("push_bulk to %s received %s", uri, e.code)
+            logger.warning(
+                "push_bulk to %s received code=%s msg=%s",
+                uri,
+                e.code,
+                e.msg,
+                exc_info=logger.isEnabledFor(logging.DEBUG),
+            )
         except Exception as ex:
-            logger.warning("push_bulk to %s threw exception %s", uri, ex)
+            logger.warning(
+                "push_bulk to %s threw exception(%s) %s args=%s",
+                uri,
+                type(ex).__name__,
+                ex,
+                ex.args,
+                exc_info=logger.isEnabledFor(logging.DEBUG),
+            )
         failed_transactions_counter.labels(service.id).inc()
         return False
 
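The push_bulk changes above guard the debug log behind logger.isEnabledFor(logging.DEBUG), so the per-event list comprehension is only built when DEBUG output will actually be emitted, and pass the same check as exc_info so a traceback is attached to the warning only while debugging. A standalone sketch of that pattern (the helper names here are hypothetical, not Synapse APIs):

import logging

logger = logging.getLogger(__name__)

def log_success(uri: str, events: list) -> None:
    # The event-id list is only built when DEBUG is enabled, so the common
    # case pays nothing for the comprehension.
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(
            "push_bulk to %s succeeded! events=%s",
            uri,
            [event.get("event_id") for event in events],
        )

def log_failure(uri: str, exc: Exception) -> None:
    # exc_info accepts a bool: attach the stack trace to the warning only
    # when the logger is at DEBUG, keeping normal WARNING output terse.
    logger.warning(
        "push_bulk to %s threw exception(%s) %s",
        uri,
        type(exc).__name__,
        exc,
        exc_info=logger.isEnabledFor(logging.DEBUG),
    )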
diff --git a/synapse/config/__main__.py b/synapse/config/__main__.py
index c555f5f914..b2a7a89a35 100644
--- a/synapse/config/__main__.py
+++ b/synapse/config/__main__.py
@@ -13,12 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import sys
+from typing import List
 
 from synapse.config._base import ConfigError
 from synapse.config.homeserver import HomeServerConfig
 
 
-def main(args):
+def main(args: List[str]) -> None:
     action = args[1] if len(args) > 1 and args[1] == "read" else None
     # If we're reading a key in the config file, then `args[1]` will be `read`  and `args[2]`
     # will be the key to read.
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 7c4428a138..1265738dc1 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -20,7 +20,18 @@ import os
 from collections import OrderedDict
 from hashlib import sha256
 from textwrap import dedent
-from typing import Any, Iterable, List, MutableMapping, Optional, Union
+from typing import (
+    Any,
+    Dict,
+    Iterable,
+    List,
+    MutableMapping,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)
 
 import attr
 import jinja2
@@ -78,7 +89,7 @@ CONFIG_FILE_HEADER = """\
 """
 
 
-def path_exists(file_path):
+def path_exists(file_path: str) -> bool:
     """Check if a file exists
 
     Unlike os.path.exists, this throws an exception if there is an error
@@ -86,7 +97,7 @@ def path_exists(file_path):
     the parent dir).
 
     Returns:
-        bool: True if the file exists; False if not.
+        True if the file exists; False if not.
     """
     try:
         os.stat(file_path)
@@ -102,15 +113,15 @@ class Config:
     A configuration section, containing configuration keys and values.
 
     Attributes:
-        section (str): The section title of this config object, such as
+        section: The section title of this config object, such as
             "tls" or "logger". This is used to refer to it on the root
             logger (for example, `config.tls.some_option`). Must be
             defined in subclasses.
     """
 
-    section = None
+    section: str
 
-    def __init__(self, root_config=None):
+    def __init__(self, root_config: "RootConfig" = None):
         self.root = root_config
 
         # Get the path to the default Synapse template directory
@@ -119,7 +130,7 @@ class Config:
         )
 
     @staticmethod
-    def parse_size(value):
+    def parse_size(value: Union[str, int]) -> int:
         if isinstance(value, int):
             return value
         sizes = {"K": 1024, "M": 1024 * 1024}
@@ -162,15 +173,15 @@ class Config:
         return int(value) * size
 
     @staticmethod
-    def abspath(file_path):
+    def abspath(file_path: str) -> str:
         return os.path.abspath(file_path) if file_path else file_path
 
     @classmethod
-    def path_exists(cls, file_path):
+    def path_exists(cls, file_path: str) -> bool:
         return path_exists(file_path)
 
     @classmethod
-    def check_file(cls, file_path, config_name):
+    def check_file(cls, file_path: Optional[str], config_name: str) -> str:
         if file_path is None:
             raise ConfigError("Missing config for %s." % (config_name,))
         try:
@@ -183,7 +194,7 @@ class Config:
         return cls.abspath(file_path)
 
     @classmethod
-    def ensure_directory(cls, dir_path):
+    def ensure_directory(cls, dir_path: str) -> str:
         dir_path = cls.abspath(dir_path)
         os.makedirs(dir_path, exist_ok=True)
         if not os.path.isdir(dir_path):
@@ -191,7 +202,7 @@ class Config:
         return dir_path
 
     @classmethod
-    def read_file(cls, file_path, config_name):
+    def read_file(cls, file_path: Any, config_name: str) -> str:
         """Deprecated: call read_file directly"""
         return read_file(file_path, (config_name,))
 
@@ -284,6 +295,9 @@ class Config:
         return [env.get_template(filename) for filename in filenames]
 
 
+TRootConfig = TypeVar("TRootConfig", bound="RootConfig")
+
+
 class RootConfig:
     """
     Holder of an application's configuration.
@@ -308,7 +322,9 @@ class RootConfig:
                 raise Exception("Failed making %s: %r" % (config_class.section, e))
             setattr(self, config_class.section, conf)
 
-    def invoke_all(self, func_name: str, *args, **kwargs) -> MutableMapping[str, Any]:
+    def invoke_all(
+        self, func_name: str, *args: Any, **kwargs: Any
+    ) -> MutableMapping[str, Any]:
         """
         Invoke a function on all instantiated config objects this RootConfig is
         configured to use.
@@ -317,6 +333,7 @@ class RootConfig:
             func_name: Name of function to invoke
             *args
             **kwargs
+
         Returns:
             ordered dictionary of config section name and the result of the
             function from it.
@@ -332,7 +349,7 @@ class RootConfig:
         return res
 
     @classmethod
-    def invoke_all_static(cls, func_name: str, *args, **kwargs):
+    def invoke_all_static(cls, func_name: str, *args: Any, **kwargs: Any) -> None:
         """
         Invoke a static function on config objects this RootConfig is
         configured to use.
@@ -341,6 +358,7 @@ class RootConfig:
             func_name: Name of function to invoke
             *args
             **kwargs
+
         Returns:
             ordered dictionary of config section name and the result of the
             function from it.
@@ -351,16 +369,16 @@ class RootConfig:
 
     def generate_config(
         self,
-        config_dir_path,
-        data_dir_path,
-        server_name,
-        generate_secrets=False,
-        report_stats=None,
-        open_private_ports=False,
-        listeners=None,
-        tls_certificate_path=None,
-        tls_private_key_path=None,
-    ):
+        config_dir_path: str,
+        data_dir_path: str,
+        server_name: str,
+        generate_secrets: bool = False,
+        report_stats: Optional[bool] = None,
+        open_private_ports: bool = False,
+        listeners: Optional[List[dict]] = None,
+        tls_certificate_path: Optional[str] = None,
+        tls_private_key_path: Optional[str] = None,
+    ) -> str:
         """
         Build a default configuration file
 
@@ -368,27 +386,27 @@ class RootConfig:
         (eg with --generate_config).
 
         Args:
-            config_dir_path (str): The path where the config files are kept. Used to
+            config_dir_path: The path where the config files are kept. Used to
                 create filenames for things like the log config and the signing key.
 
-            data_dir_path (str): The path where the data files are kept. Used to create
+            data_dir_path: The path where the data files are kept. Used to create
                 filenames for things like the database and media store.
 
-            server_name (str): The server name. Used to initialise the server_name
+            server_name: The server name. Used to initialise the server_name
                 config param, but also used in the names of some of the config files.
 
-            generate_secrets (bool): True if we should generate new secrets for things
+            generate_secrets: True if we should generate new secrets for things
                 like the macaroon_secret_key. If False, these parameters will be left
                 unset.
 
-            report_stats (bool|None): Initial setting for the report_stats setting.
+            report_stats: Initial setting for the report_stats setting.
                 If None, report_stats will be left unset.
 
-            open_private_ports (bool): True to leave private ports (such as the non-TLS
+            open_private_ports: True to leave private ports (such as the non-TLS
                 HTTP listener) open to the internet.
 
-            listeners (list(dict)|None): A list of descriptions of the listeners
-                synapse should start with each of which specifies a port (str), a list of
+            listeners: A list of descriptions of the listeners synapse should
+                start with each of which specifies a port (int), a list of
                 resources (list(str)), tls (bool) and type (str). For example:
                 [{
                     "port": 8448,
@@ -403,16 +421,12 @@ class RootConfig:
                     "type": "http",
                 }],
 
+            tls_certificate_path: The path to the tls certificate.
 
-            database (str|None): The database type to configure, either `psycog2`
-                or `sqlite3`.
-
-            tls_certificate_path (str|None): The path to the tls certificate.
-
-            tls_private_key_path (str|None): The path to the tls private key.
+            tls_private_key_path: The path to the tls private key.
 
         Returns:
-            str: the yaml config file
+            The yaml config file
         """
 
         return CONFIG_FILE_HEADER + "\n\n".join(
@@ -432,12 +446,15 @@ class RootConfig:
         )
 
     @classmethod
-    def load_config(cls, description, argv):
+    def load_config(
+        cls: Type[TRootConfig], description: str, argv: List[str]
+    ) -> TRootConfig:
         """Parse the commandline and config files
 
         Doesn't support config-file-generation: used by the worker apps.
 
-        Returns: Config object.
+        Returns:
+            Config object.
         """
         config_parser = argparse.ArgumentParser(description=description)
         cls.add_arguments_to_parser(config_parser)
@@ -446,7 +463,7 @@ class RootConfig:
         return obj
 
     @classmethod
-    def add_arguments_to_parser(cls, config_parser):
+    def add_arguments_to_parser(cls, config_parser: argparse.ArgumentParser) -> None:
         """Adds all the config flags to an ArgumentParser.
 
         Doesn't support config-file-generation: used by the worker apps.
@@ -454,7 +471,7 @@ class RootConfig:
         Used for workers where we want to add extra flags/subcommands.
 
         Args:
-            config_parser (ArgumentParser): App description
+            config_parser: App description
         """
 
         config_parser.add_argument(
@@ -477,7 +494,9 @@ class RootConfig:
         cls.invoke_all_static("add_arguments", config_parser)
 
     @classmethod
-    def load_config_with_parser(cls, parser, argv):
+    def load_config_with_parser(
+        cls: Type[TRootConfig], parser: argparse.ArgumentParser, argv: List[str]
+    ) -> Tuple[TRootConfig, argparse.Namespace]:
         """Parse the commandline and config files with the given parser
 
         Doesn't support config-file-generation: used by the worker apps.
@@ -485,13 +504,12 @@ class RootConfig:
         Used for workers where we want to add extra flags/subcommands.
 
         Args:
-            parser (ArgumentParser)
-            argv (list[str])
+            parser
+            argv
 
         Returns:
-            tuple[HomeServerConfig, argparse.Namespace]: Returns the parsed
-            config object and the parsed argparse.Namespace object from
-            `parser.parse_args(..)`
+            Returns the parsed config object and the parsed argparse.Namespace
+            object from parser.parse_args(..)`
         """
 
         obj = cls()
@@ -520,12 +538,15 @@ class RootConfig:
         return obj, config_args
 
     @classmethod
-    def load_or_generate_config(cls, description, argv):
+    def load_or_generate_config(
+        cls: Type[TRootConfig], description: str, argv: List[str]
+    ) -> Optional[TRootConfig]:
         """Parse the commandline and config files
 
         Supports generation of config files, so is used for the main homeserver app.
 
-        Returns: Config object, or None if --generate-config or --generate-keys was set
+        Returns:
+            Config object, or None if --generate-config or --generate-keys was set
         """
         parser = argparse.ArgumentParser(description=description)
         parser.add_argument(
@@ -680,16 +701,21 @@ class RootConfig:
 
         return obj
 
-    def parse_config_dict(self, config_dict, config_dir_path=None, data_dir_path=None):
+    def parse_config_dict(
+        self,
+        config_dict: Dict[str, Any],
+        config_dir_path: Optional[str] = None,
+        data_dir_path: Optional[str] = None,
+    ) -> None:
         """Read the information from the config dict into this Config object.
 
         Args:
-            config_dict (dict): Configuration data, as read from the yaml
+            config_dict: Configuration data, as read from the yaml
 
-            config_dir_path (str): The path where the config files are kept. Used to
+            config_dir_path: The path where the config files are kept. Used to
                 create filenames for things like the log config and the signing key.
 
-            data_dir_path (str): The path where the data files are kept. Used to create
+            data_dir_path: The path where the data files are kept. Used to create
                 filenames for things like the database and media store.
         """
         self.invoke_all(
@@ -699,17 +725,20 @@ class RootConfig:
             data_dir_path=data_dir_path,
         )
 
-    def generate_missing_files(self, config_dict, config_dir_path):
+    def generate_missing_files(
+        self, config_dict: Dict[str, Any], config_dir_path: str
+    ) -> None:
         self.invoke_all("generate_files", config_dict, config_dir_path)
 
 
-def read_config_files(config_files):
+def read_config_files(config_files: Iterable[str]) -> Dict[str, Any]:
     """Read the config files into a dict
 
     Args:
-        config_files (iterable[str]): A list of the config files to read
+        config_files: A list of the config files to read
 
-    Returns: dict
+    Returns:
+        The configuration dictionary.
     """
     specified_config = {}
     for config_file in config_files:
@@ -733,17 +762,17 @@ def read_config_files(config_files):
     return specified_config
 
 
-def find_config_files(search_paths):
+def find_config_files(search_paths: List[str]) -> List[str]:
     """Finds config files using a list of search paths. If a path is a file
     then that file path is added to the list. If a search path is a directory
     then all the "*.yaml" files in that directory are added to the list in
     sorted order.
 
     Args:
-        search_paths(list(str)): A list of paths to search.
+        search_paths: A list of paths to search.
 
     Returns:
-        list(str): A list of file paths.
+        A list of file paths.
     """
 
     config_files = []
@@ -777,7 +806,7 @@ def find_config_files(search_paths):
     return config_files
 
 
-@attr.s
+@attr.s(auto_attribs=True)
 class ShardedWorkerHandlingConfig:
     """Algorithm for choosing which instance is responsible for handling some
     sharded work.
@@ -787,7 +816,7 @@ class ShardedWorkerHandlingConfig:
     below).
     """
 
-    instances = attr.ib(type=List[str])
+    instances: List[str]
 
     def should_handle(self, instance_name: str, key: str) -> bool:
         """Whether this instance is responsible for handling the given key."""
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index c1d9069798..1eb5f5a68c 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -1,4 +1,18 @@
-from typing import Any, Iterable, List, Optional
+import argparse
+from typing import (
+    Any,
+    Dict,
+    Iterable,
+    List,
+    MutableMapping,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)
+
+import jinja2
 
 from synapse.config import (
     account_validity,
@@ -19,6 +33,7 @@ from synapse.config import (
     logger,
     metrics,
     modules,
+    oembed,
     oidc,
     password_auth_providers,
     push,
@@ -27,6 +42,7 @@ from synapse.config import (
     registration,
     repository,
     retention,
+    room,
     room_directory,
     saml2,
     server,
@@ -51,7 +67,9 @@ MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS: str
 MISSING_REPORT_STATS_SPIEL: str
 MISSING_SERVER_NAME: str
 
-def path_exists(file_path: str): ...
+def path_exists(file_path: str) -> bool: ...
+
+TRootConfig = TypeVar("TRootConfig", bound="RootConfig")
 
 class RootConfig:
     server: server.ServerConfig
@@ -61,6 +79,7 @@ class RootConfig:
     logging: logger.LoggingConfig
     ratelimiting: ratelimiting.RatelimitConfig
     media: repository.ContentRepositoryConfig
+    oembed: oembed.OembedConfig
     captcha: captcha.CaptchaConfig
     voip: voip.VoipConfig
     registration: registration.RegistrationConfig
@@ -80,6 +99,7 @@ class RootConfig:
     authproviders: password_auth_providers.PasswordAuthProviderConfig
     push: push.PushConfig
     spamchecker: spam_checker.SpamCheckerConfig
+    room: room.RoomConfig
     groups: groups.GroupsConfig
     userdirectory: user_directory.UserDirectoryConfig
     consent: consent.ConsentConfig
@@ -87,72 +107,85 @@ class RootConfig:
     servernotices: server_notices.ServerNoticesConfig
     roomdirectory: room_directory.RoomDirectoryConfig
     thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig
-    tracer: tracer.TracerConfig
+    tracing: tracer.TracerConfig
     redis: redis.RedisConfig
     modules: modules.ModulesConfig
     caches: cache.CacheConfig
     federation: federation.FederationConfig
     retention: retention.RetentionConfig
 
-    config_classes: List = ...
+    config_classes: List[Type["Config"]] = ...
     def __init__(self) -> None: ...
-    def invoke_all(self, func_name: str, *args: Any, **kwargs: Any): ...
+    def invoke_all(
+        self, func_name: str, *args: Any, **kwargs: Any
+    ) -> MutableMapping[str, Any]: ...
     @classmethod
     def invoke_all_static(cls, func_name: str, *args: Any, **kwargs: Any) -> None: ...
-    def __getattr__(self, item: str): ...
     def parse_config_dict(
         self,
-        config_dict: Any,
-        config_dir_path: Optional[Any] = ...,
-        data_dir_path: Optional[Any] = ...,
+        config_dict: Dict[str, Any],
+        config_dir_path: Optional[str] = ...,
+        data_dir_path: Optional[str] = ...,
     ) -> None: ...
-    read_config: Any = ...
     def generate_config(
         self,
         config_dir_path: str,
         data_dir_path: str,
         server_name: str,
         generate_secrets: bool = ...,
-        report_stats: Optional[str] = ...,
+        report_stats: Optional[bool] = ...,
         open_private_ports: bool = ...,
         listeners: Optional[Any] = ...,
-        database_conf: Optional[Any] = ...,
         tls_certificate_path: Optional[str] = ...,
         tls_private_key_path: Optional[str] = ...,
-    ): ...
+    ) -> str: ...
     @classmethod
-    def load_or_generate_config(cls, description: Any, argv: Any): ...
+    def load_or_generate_config(
+        cls: Type[TRootConfig], description: str, argv: List[str]
+    ) -> Optional[TRootConfig]: ...
     @classmethod
-    def load_config(cls, description: Any, argv: Any): ...
+    def load_config(
+        cls: Type[TRootConfig], description: str, argv: List[str]
+    ) -> TRootConfig: ...
     @classmethod
-    def add_arguments_to_parser(cls, config_parser: Any) -> None: ...
+    def add_arguments_to_parser(
+        cls, config_parser: argparse.ArgumentParser
+    ) -> None: ...
     @classmethod
-    def load_config_with_parser(cls, parser: Any, argv: Any): ...
+    def load_config_with_parser(
+        cls: Type[TRootConfig], parser: argparse.ArgumentParser, argv: List[str]
+    ) -> Tuple[TRootConfig, argparse.Namespace]: ...
     def generate_missing_files(
         self, config_dict: dict, config_dir_path: str
     ) -> None: ...
 
 class Config:
     root: RootConfig
+    default_template_dir: str
     def __init__(self, root_config: Optional[RootConfig] = ...) -> None: ...
-    def __getattr__(self, item: str, from_root: bool = ...): ...
     @staticmethod
-    def parse_size(value: Any): ...
+    def parse_size(value: Union[str, int]) -> int: ...
     @staticmethod
-    def parse_duration(value: Any): ...
+    def parse_duration(value: Union[str, int]) -> int: ...
     @staticmethod
-    def abspath(file_path: Optional[str]): ...
+    def abspath(file_path: Optional[str]) -> str: ...
     @classmethod
-    def path_exists(cls, file_path: str): ...
+    def path_exists(cls, file_path: str) -> bool: ...
     @classmethod
-    def check_file(cls, file_path: str, config_name: str): ...
+    def check_file(cls, file_path: str, config_name: str) -> str: ...
     @classmethod
-    def ensure_directory(cls, dir_path: str): ...
+    def ensure_directory(cls, dir_path: str) -> str: ...
     @classmethod
-    def read_file(cls, file_path: str, config_name: str): ...
+    def read_file(cls, file_path: str, config_name: str) -> str: ...
+    def read_template(self, filenames: str) -> jinja2.Template: ...
+    def read_templates(
+        self,
+        filenames: List[str],
+        custom_template_directories: Optional[Iterable[str]] = None,
+    ) -> List[jinja2.Template]: ...
 
-def read_config_files(config_files: List[str]): ...
-def find_config_files(search_paths: List[str]): ...
+def read_config_files(config_files: Iterable[str]) -> Dict[str, Any]: ...
+def find_config_files(search_paths: List[str]) -> List[str]: ...
 
 class ShardedWorkerHandlingConfig:
     instances: List[str]
diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index 1ebea88db2..e4bb7224a4 100644
--- a/synapse/config/appservice.py
+++ b/synapse/config/appservice.py
@@ -1,4 +1,5 @@
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,14 +14,14 @@
 # limitations under the License.
 
 import logging
-from typing import Dict
+from typing import Dict, List
 from urllib import parse as urlparse
 
 import yaml
 from netaddr import IPSet
 
 from synapse.appservice import ApplicationService
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
 
 from ._base import Config, ConfigError
 
@@ -30,12 +31,12 @@ logger = logging.getLogger(__name__)
 class AppServiceConfig(Config):
     section = "appservice"
 
-    def read_config(self, config, **kwargs):
+    def read_config(self, config, **kwargs) -> None:
         self.app_service_config_files = config.get("app_service_config_files", [])
         self.notify_appservices = config.get("notify_appservices", True)
         self.track_appservice_user_ips = config.get("track_appservice_user_ips", False)
 
-    def generate_config_section(cls, **kwargs):
+    def generate_config_section(cls, **kwargs) -> str:
         return """\
         # A list of application service config files to use
         #
@@ -50,7 +51,9 @@ class AppServiceConfig(Config):
         """
 
 
-def load_appservices(hostname, config_files):
+def load_appservices(
+    hostname: str, config_files: List[str]
+) -> List[ApplicationService]:
     """Returns a list of Application Services from the config files."""
     if not isinstance(config_files, list):
         logger.warning("Expected %s to be a list of AS config files.", config_files)
@@ -93,7 +96,9 @@ def load_appservices(hostname, config_files):
     return appservices
 
 
-def _load_appservice(hostname, as_info, config_filename):
+def _load_appservice(
+    hostname: str, as_info: JsonDict, config_filename: str
+) -> ApplicationService:
     required_string_fields = ["id", "as_token", "hs_token", "sender_localpart"]
     for field in required_string_fields:
         if not isinstance(as_info.get(field), str):
@@ -115,9 +120,9 @@ def _load_appservice(hostname, as_info, config_filename):
     user_id = user.to_string()
 
     # Rate limiting for users of this AS is on by default (excludes sender)
-    rate_limited = True
-    if isinstance(as_info.get("rate_limited"), bool):
-        rate_limited = as_info.get("rate_limited")
+    rate_limited = as_info.get("rate_limited")
+    if not isinstance(rate_limited, bool):
+        rate_limited = True
 
     # namespace checks
     if not isinstance(as_info.get("namespaces"), dict):
diff --git a/synapse/config/cache.py b/synapse/config/cache.py
index d119427ad8..d9d85f98e1 100644
--- a/synapse/config/cache.py
+++ b/synapse/config/cache.py
@@ -1,4 +1,4 @@
-# Copyright 2019 Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,7 +15,9 @@
 import os
 import re
 import threading
-from typing import Callable, Dict
+from typing import Callable, Dict, Optional
+
+import attr
 
 from synapse.python_dependencies import DependencyException, check_requirements
 
@@ -34,13 +36,13 @@ _DEFAULT_FACTOR_SIZE = 0.5
 _DEFAULT_EVENT_CACHE_SIZE = "10K"
 
 
+@attr.s(slots=True, auto_attribs=True)
 class CacheProperties:
-    def __init__(self):
-        # The default factor size for all caches
-        self.default_factor_size = float(
-            os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE)
-        )
-        self.resize_all_caches_func = None
+    # The default factor size for all caches
+    default_factor_size: float = float(
+        os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE)
+    )
+    resize_all_caches_func: Optional[Callable[[], None]] = None
 
 
 properties = CacheProperties()
@@ -62,7 +64,7 @@ def _canonicalise_cache_name(cache_name: str) -> str:
 
 def add_resizable_cache(
     cache_name: str, cache_resize_callback: Callable[[float], None]
-):
+) -> None:
     """Register a cache that's size can dynamically change
 
     Args:
@@ -91,7 +93,7 @@ class CacheConfig(Config):
     _environ = os.environ
 
     @staticmethod
-    def reset():
+    def reset() -> None:
         """Resets the caches to their defaults. Used for tests."""
         properties.default_factor_size = float(
             os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE)
@@ -100,7 +102,7 @@ class CacheConfig(Config):
         with _CACHES_LOCK:
             _CACHES.clear()
 
-    def generate_config_section(self, **kwargs):
+    def generate_config_section(self, **kwargs) -> str:
         return """\
         ## Caching ##
 
@@ -162,7 +164,7 @@ class CacheConfig(Config):
           #sync_response_cache_duration: 2m
         """
 
-    def read_config(self, config, **kwargs):
+    def read_config(self, config, **kwargs) -> None:
         self.event_cache_size = self.parse_size(
             config.get("event_cache_size", _DEFAULT_EVENT_CACHE_SIZE)
         )
@@ -217,7 +219,7 @@ class CacheConfig(Config):
 
         expiry_time = cache_config.get("expiry_time")
         if expiry_time:
-            self.expiry_time_msec = self.parse_duration(expiry_time)
+            self.expiry_time_msec: Optional[int] = self.parse_duration(expiry_time)
         else:
             self.expiry_time_msec = None
 
@@ -232,7 +234,7 @@ class CacheConfig(Config):
         # needing an instance of Config
         properties.resize_all_caches_func = self.resize_all_caches
 
-    def resize_all_caches(self):
+    def resize_all_caches(self) -> None:
         """Ensure all cache sizes are up to date
 
         For each cache, run the mapped callback function with either
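The cache.py hunk swaps a hand-written __init__ for attrs with auto_attribs: every annotated class attribute becomes a constructor argument carrying its default, and slots=True drops the per-instance __dict__. A rough equivalent of the new CacheProperties shape (names simplified, not the real class):

from typing import Callable, Optional

import attr

@attr.s(slots=True, auto_attribs=True)
class CacheSettings:
    # attrs generates __init__, __repr__ and __eq__ from these annotations.
    default_factor_size: float = 0.5
    resize_all_caches_func: Optional[Callable[[], None]] = None

defaults = CacheSettings()                        # uses both defaults
doubled = CacheSettings(default_factor_size=2.0)  # override a single field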
diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index 3f81814043..6f2754092e 100644
--- a/synapse/config/cas.py
+++ b/synapse/config/cas.py
@@ -1,4 +1,5 @@
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -28,7 +29,7 @@ class CasConfig(Config):
 
     section = "cas"
 
-    def read_config(self, config, **kwargs):
+    def read_config(self, config, **kwargs) -> None:
         cas_config = config.get("cas_config", None)
         self.cas_enabled = cas_config and cas_config.get("enabled", True)
 
@@ -51,7 +52,7 @@ class CasConfig(Config):
             self.cas_displayname_attribute = None
             self.cas_required_attributes = []
 
-    def generate_config_section(self, config_dir_path, server_name, **kwargs):
+    def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
         return """\
         # Enable Central Authentication Service (CAS) for registration and login.
         #
diff --git a/synapse/config/database.py b/synapse/config/database.py
index 651e31b576..06ccf15cd9 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -1,5 +1,5 @@
 # Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2020 The Matrix.org Foundation C.I.C.
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import argparse
 import logging
 import os
 
@@ -119,7 +120,7 @@ class DatabaseConfig(Config):
 
         self.databases = []
 
-    def read_config(self, config, **kwargs):
+    def read_config(self, config, **kwargs) -> None:
         # We *experimentally* support specifying multiple databases via the
         # `databases` key. This is a map from a label to database config in the
         # same format as the `database` config option, plus an extra
@@ -163,12 +164,12 @@ class DatabaseConfig(Config):
             self.databases = [DatabaseConnectionConfig("master", database_config)]
             self.set_databasepath(database_path)
 
-    def generate_config_section(self, data_dir_path, **kwargs):
+    def generate_config_section(self, data_dir_path, **kwargs) -> str:
         return DEFAULT_CONFIG % {
             "database_path": os.path.join(data_dir_path, "homeserver.db")
         }
 
-    def read_arguments(self, args):
+    def read_arguments(self, args: argparse.Namespace) -> None:
         """
         Cases for the cli input:
           - If no databases are configured and no database_path is set, raise.
@@ -194,7 +195,7 @@ class DatabaseConfig(Config):
         else:
             logger.warning(NON_SQLITE_DATABASE_PATH_WARNING)
 
-    def set_databasepath(self, database_path):
+    def set_databasepath(self, database_path: str) -> None:
 
         if database_path != ":memory:":
             database_path = self.abspath(database_path)
@@ -202,7 +203,7 @@ class DatabaseConfig(Config):
         self.databases[0].config["args"]["database"] = database_path
 
     @staticmethod
-    def add_arguments(parser):
+    def add_arguments(parser: argparse.ArgumentParser) -> None:
         db_group = parser.add_argument_group("database")
         db_group.add_argument(
             "-d",
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index afd65fecd3..510b647c63 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -137,33 +137,14 @@ class EmailConfig(Config):
             if self.root.registration.account_threepid_delegate_email
             else ThreepidBehaviour.LOCAL
         )
-        # Prior to Synapse v1.4.0, there was another option that defined whether Synapse would
-        # use an identity server to password reset tokens on its behalf. We now warn the user
-        # if they have this set and tell them to use the updated option, while using a default
-        # identity server in the process.
-        self.using_identity_server_from_trusted_list = False
-        if (
-            not self.root.registration.account_threepid_delegate_email
-            and config.get("trust_identity_server_for_password_resets", False) is True
-        ):
-            # Use the first entry in self.trusted_third_party_id_servers instead
-            if self.trusted_third_party_id_servers:
-                # XXX: It's a little confusing that account_threepid_delegate_email is modified
-                # both in RegistrationConfig and here. We should factor this bit out
 
-                first_trusted_identity_server = self.trusted_third_party_id_servers[0]
-
-                # trusted_third_party_id_servers does not contain a scheme whereas
-                # account_threepid_delegate_email is expected to. Presume https
-                self.root.registration.account_threepid_delegate_email = (
-                    "https://" + first_trusted_identity_server
-                )
-                self.using_identity_server_from_trusted_list = True
-            else:
-                raise ConfigError(
-                    "Attempted to use an identity server from"
-                    '"trusted_third_party_id_servers" but it is empty.'
-                )
+        if config.get("trust_identity_server_for_password_resets"):
+            raise ConfigError(
+                'The config option "trust_identity_server_for_password_resets" '
+                'has been replaced by "account_threepid_delegate". '
+                "Please consult the sample config at docs/sample_config.yaml for "
+                "details and update your config file."
+            )
 
         self.local_threepid_handling_disabled_due_to_email_config = False
         if (
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 593195ae90..e481fc16b6 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -47,6 +47,9 @@ class ExperimentalConfig(Config):
         # MSC3266 (room summary api)
         self.msc3266_enabled: bool = experimental.get("msc3266_enabled", False)
 
+        # MSC3030 (Jump to date API endpoint)
+        self.msc3030_enabled: bool = experimental.get("msc3030_enabled", False)
+
         # MSC2409 (this setting only relates to optionally sending to-device messages).
         # Presence, typing and read receipt EDUs are already sent to application services that
         # have opted in to receive them. This setting, if enabled, adds to-device messages
diff --git a/synapse/config/jwt.py b/synapse/config/jwt.py
index 9d295f5856..24c3ef01fc 100644
--- a/synapse/config/jwt.py
+++ b/synapse/config/jwt.py
@@ -31,6 +31,8 @@ class JWTConfig(Config):
             self.jwt_secret = jwt_config["secret"]
             self.jwt_algorithm = jwt_config["algorithm"]
 
+            self.jwt_subject_claim = jwt_config.get("subject_claim", "sub")
+
             # The issuer and audiences are optional, if provided, it is asserted
             # that the claims exist on the JWT.
             self.jwt_issuer = jwt_config.get("issuer")
@@ -46,6 +48,7 @@ class JWTConfig(Config):
             self.jwt_enabled = False
             self.jwt_secret = None
             self.jwt_algorithm = None
+            self.jwt_subject_claim = None
             self.jwt_issuer = None
             self.jwt_audiences = None
 
@@ -88,6 +91,12 @@ class JWTConfig(Config):
             #
             #algorithm: "provided-by-your-issuer"
 
+            # Name of the claim containing a unique identifier for the user.
+            #
+            # Optional, defaults to `sub`.
+            #
+            #subject_claim: "sub"
+
             # The issuer to validate the "iss" claim against.
             #
             # Optional, if provided the "iss" claim will be required and
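The new `subject_claim` option controls which JWT claim is treated as the user identifier, defaulting to `sub`. Purely as a hedged sketch of the idea (standalone, not Synapse's actual login handler; the claim name and payload below are made up for illustration):

    # Sketch only: mapping an already-verified JWT payload to a localpart,
    # using whichever claim `subject_claim` names (default "sub").
    from typing import Any, Dict

    def localpart_from_jwt_payload(payload: Dict[str, Any], subject_claim: str = "sub") -> str:
        user_id = payload.get(subject_claim)
        if user_id is None:
            raise ValueError("JWT payload is missing claim %r" % subject_claim)
        return str(user_id)

    # With `subject_claim: "preferred_username"` configured:
    assert localpart_from_jwt_payload({"preferred_username": "alice"}, "preferred_username") == "alice"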
diff --git a/synapse/config/key.py b/synapse/config/key.py
index 015dbb8a67..035ee2416b 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -16,6 +16,7 @@
 import hashlib
 import logging
 import os
+from typing import Any, Dict
 
 import attr
 import jsonschema
@@ -312,7 +313,7 @@ class KeyConfig(Config):
                 )
         return keys
 
-    def generate_files(self, config, config_dir_path):
+    def generate_files(self, config: Dict[str, Any], config_dir_path: str) -> None:
         if "signing_key" in config:
             return
 
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 5252e61a99..ea69b9bd9b 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -1,4 +1,5 @@
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,7 +19,7 @@ import os
 import sys
 import threading
 from string import Template
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Dict, Optional
 
 import yaml
 from zope.interface import implementer
@@ -40,6 +41,7 @@ from synapse.util.versionstring import get_version_string
 from ._base import Config, ConfigError
 
 if TYPE_CHECKING:
+    from synapse.config.homeserver import HomeServerConfig
     from synapse.server import HomeServer
 
 DEFAULT_LOG_CONFIG = Template(
@@ -141,13 +143,13 @@ removed in Synapse 1.3.0. You should instead set up a separate log configuration
 class LoggingConfig(Config):
     section = "logging"
 
-    def read_config(self, config, **kwargs):
+    def read_config(self, config, **kwargs) -> None:
         if config.get("log_file"):
             raise ConfigError(LOG_FILE_ERROR)
         self.log_config = self.abspath(config.get("log_config"))
         self.no_redirect_stdio = config.get("no_redirect_stdio", False)
 
-    def generate_config_section(self, config_dir_path, server_name, **kwargs):
+    def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
         log_config = os.path.join(config_dir_path, server_name + ".log.config")
         return (
             """\
@@ -161,14 +163,14 @@ class LoggingConfig(Config):
             % locals()
         )
 
-    def read_arguments(self, args):
+    def read_arguments(self, args: argparse.Namespace) -> None:
         if args.no_redirect_stdio is not None:
             self.no_redirect_stdio = args.no_redirect_stdio
         if args.log_file is not None:
             raise ConfigError(LOG_FILE_ERROR)
 
     @staticmethod
-    def add_arguments(parser):
+    def add_arguments(parser: argparse.ArgumentParser) -> None:
         logging_group = parser.add_argument_group("logging")
         logging_group.add_argument(
             "-n",
@@ -185,7 +187,7 @@ class LoggingConfig(Config):
             help=argparse.SUPPRESS,
         )
 
-    def generate_files(self, config, config_dir_path):
+    def generate_files(self, config: Dict[str, Any], config_dir_path: str) -> None:
         log_config = config.get("log_config")
         if log_config and not os.path.exists(log_config):
             log_file = self.abspath("homeserver.log")
@@ -197,7 +199,9 @@ class LoggingConfig(Config):
                 log_config_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=log_file))
 
 
-def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) -> None:
+def _setup_stdlib_logging(
+    config: "HomeServerConfig", log_config_path: Optional[str], logBeginner: LogBeginner
+) -> None:
     """
     Set up Python standard library logging.
     """
@@ -230,7 +234,7 @@ def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) ->
     log_metadata_filter = MetadataFilter({"server_name": config.server.server_name})
     old_factory = logging.getLogRecordFactory()
 
-    def factory(*args, **kwargs):
+    def factory(*args: Any, **kwargs: Any) -> logging.LogRecord:
         record = old_factory(*args, **kwargs)
         log_context_filter.filter(record)
         log_metadata_filter.filter(record)
@@ -297,7 +301,7 @@ def _load_logging_config(log_config_path: str) -> None:
     logging.config.dictConfig(log_config)
 
 
-def _reload_logging_config(log_config_path):
+def _reload_logging_config(log_config_path: Optional[str]) -> None:
     """
     Reload the log configuration from the file and apply it.
     """
@@ -311,8 +315,8 @@ def _reload_logging_config(log_config_path):
 
 def setup_logging(
     hs: "HomeServer",
-    config,
-    use_worker_options=False,
+    config: "HomeServerConfig",
+    use_worker_options: bool = False,
     logBeginner: LogBeginner = globalLogBeginner,
 ) -> None:
     """
diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py
index 42f113cd24..79c400fe30 100644
--- a/synapse/config/oidc.py
+++ b/synapse/config/oidc.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 from collections import Counter
-from typing import Collection, Iterable, List, Mapping, Optional, Tuple, Type
+from typing import Any, Collection, Iterable, List, Mapping, Optional, Tuple, Type
 
 import attr
 
@@ -36,7 +36,7 @@ LEGACY_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingPr
 class OIDCConfig(Config):
     section = "oidc"
 
-    def read_config(self, config, **kwargs):
+    def read_config(self, config, **kwargs) -> None:
         self.oidc_providers = tuple(_parse_oidc_provider_configs(config))
         if not self.oidc_providers:
             return
@@ -66,7 +66,7 @@ class OIDCConfig(Config):
         # OIDC is enabled if we have a provider
         return bool(self.oidc_providers)
 
-    def generate_config_section(self, config_dir_path, server_name, **kwargs):
+    def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
         return """\
         # List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration
         # and login.
@@ -495,89 +495,89 @@ def _parse_oidc_config_dict(
     )
 
 
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class OidcProviderClientSecretJwtKey:
     # a pem-encoded signing key
-    key = attr.ib(type=str)
+    key: str
 
     # properties to include in the JWT header
-    jwt_header = attr.ib(type=Mapping[str, str])
+    jwt_header: Mapping[str, str]
 
     # properties to include in the JWT payload.
-    jwt_payload = attr.ib(type=Mapping[str, str])
+    jwt_payload: Mapping[str, str]
 
 
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class OidcProviderConfig:
     # a unique identifier for this identity provider. Used in the 'user_external_ids'
     # table, as well as the query/path parameter used in the login protocol.
-    idp_id = attr.ib(type=str)
+    idp_id: str
 
     # user-facing name for this identity provider.
-    idp_name = attr.ib(type=str)
+    idp_name: str
 
     # Optional MXC URI for icon for this IdP.
-    idp_icon = attr.ib(type=Optional[str])
+    idp_icon: Optional[str]
 
     # Optional brand identifier for this IdP.
-    idp_brand = attr.ib(type=Optional[str])
+    idp_brand: Optional[str]
 
     # whether the OIDC discovery mechanism is used to discover endpoints
-    discover = attr.ib(type=bool)
+    discover: bool
 
     # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
     # discover the provider's endpoints.
-    issuer = attr.ib(type=str)
+    issuer: str
 
     # oauth2 client id to use
-    client_id = attr.ib(type=str)
+    client_id: str
 
     # oauth2 client secret to use. if `None`, use client_secret_jwt_key to generate
     # a secret.
-    client_secret = attr.ib(type=Optional[str])
+    client_secret: Optional[str]
 
     # key to use to construct a JWT to use as a client secret. May be `None` if
     # `client_secret` is set.
-    client_secret_jwt_key = attr.ib(type=Optional[OidcProviderClientSecretJwtKey])
+    client_secret_jwt_key: Optional[OidcProviderClientSecretJwtKey]
 
     # auth method to use when exchanging the token.
     # Valid values are 'client_secret_basic', 'client_secret_post' and
     # 'none'.
-    client_auth_method = attr.ib(type=str)
+    client_auth_method: str
 
     # list of scopes to request
-    scopes = attr.ib(type=Collection[str])
+    scopes: Collection[str]
 
     # the oauth2 authorization endpoint. Required if discovery is disabled.
-    authorization_endpoint = attr.ib(type=Optional[str])
+    authorization_endpoint: Optional[str]
 
     # the oauth2 token endpoint. Required if discovery is disabled.
-    token_endpoint = attr.ib(type=Optional[str])
+    token_endpoint: Optional[str]
 
     # the OIDC userinfo endpoint. Required if discovery is disabled and the
     # "openid" scope is not requested.
-    userinfo_endpoint = attr.ib(type=Optional[str])
+    userinfo_endpoint: Optional[str]
 
     # URI where to fetch the JWKS. Required if discovery is disabled and the
     # "openid" scope is used.
-    jwks_uri = attr.ib(type=Optional[str])
+    jwks_uri: Optional[str]
 
     # Whether to skip metadata verification
-    skip_verification = attr.ib(type=bool)
+    skip_verification: bool
 
     # Whether to fetch the user profile from the userinfo endpoint. Valid
     # values are: "auto" or "userinfo_endpoint".
-    user_profile_method = attr.ib(type=str)
+    user_profile_method: str
 
     # whether to allow a user logging in via OIDC to match a pre-existing account
     # instead of failing
-    allow_existing_users = attr.ib(type=bool)
+    allow_existing_users: bool
 
     # the class of the user mapping provider
-    user_mapping_provider_class = attr.ib(type=Type)
+    user_mapping_provider_class: Type
 
     # the config of the user mapping provider
-    user_mapping_provider_config = attr.ib()
+    user_mapping_provider_config: Any
 
     # required attributes to require in userinfo to allow login/registration
-    attribute_requirements = attr.ib(type=List[SsoAttributeRequirement])
+    attribute_requirements: List[SsoAttributeRequirement]
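The remainder of this hunk is a mechanical conversion from `attr.ib(type=...)` fields to `auto_attribs=True` annotations; the generated `__init__` and the attribute set are unchanged. A small self-contained illustration of the equivalence (toy class, not from Synapse):

    import attr

    @attr.s(slots=True, frozen=True)
    class BeforeStyle:
        name = attr.ib(type=str)
        count = attr.ib(type=int)

    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class AfterStyle:
        name: str
        count: int

    # Both styles produce the same positional constructor and frozen, slotted instances.
    assert BeforeStyle("a", 1).count == AfterStyle("a", 1).count == 1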
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 5379e80715..7a059c6dec 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -1,4 +1,5 @@
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,6 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import argparse
+from typing import Optional
 
 from synapse.api.constants import RoomCreationPreset
 from synapse.config._base import Config, ConfigError
@@ -39,9 +42,7 @@ class RegistrationConfig(Config):
         self.registration_shared_secret = config.get("registration_shared_secret")
 
         self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
-        self.trusted_third_party_id_servers = config.get(
-            "trusted_third_party_id_servers", ["matrix.org", "vector.im"]
-        )
+
         account_threepid_delegates = config.get("account_threepid_delegates") or {}
         self.account_threepid_delegate_email = account_threepid_delegates.get("email")
         self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn")
@@ -114,26 +115,74 @@ class RegistrationConfig(Config):
             session_lifetime = self.parse_duration(session_lifetime)
         self.session_lifetime = session_lifetime
 
-        # The `access_token_lifetime` applies for tokens that can be renewed
-        # using a refresh token, as per MSC2918. If it is `None`, the refresh
-        # token mechanism is disabled.
-        #
-        # Since it is incompatible with the `session_lifetime` mechanism, it is set to
-        # `None` by default if a `session_lifetime` is set.
-        access_token_lifetime = config.get(
-            "access_token_lifetime", "5m" if session_lifetime is None else None
+        # The `refreshable_access_token_lifetime` applies for tokens that can be renewed
+        # using a refresh token, as per MSC2918.
+        # If it is `None`, the refresh token mechanism is disabled.
+        refreshable_access_token_lifetime = config.get(
+            "refreshable_access_token_lifetime",
+            "5m",
+        )
+        if refreshable_access_token_lifetime is not None:
+            refreshable_access_token_lifetime = self.parse_duration(
+                refreshable_access_token_lifetime
+            )
+        self.refreshable_access_token_lifetime: Optional[
+            int
+        ] = refreshable_access_token_lifetime
+
+        if (
+            self.session_lifetime is not None
+            and "refreshable_access_token_lifetime" in config
+        ):
+            if self.session_lifetime < self.refreshable_access_token_lifetime:
+                raise ConfigError(
+                    "Both `session_lifetime` and `refreshable_access_token_lifetime` "
+                    "configuration options have been set, but `refreshable_access_token_lifetime` "
+                    "exceeds `session_lifetime`!"
+                )
+
+        # The `nonrefreshable_access_token_lifetime` applies for tokens that can NOT be
+        # refreshed using a refresh token.
+        # If it is None, then these tokens last for the entire length of the session,
+        # which is infinite by default.
+        # The intention behind this configuration option is to help with requiring
+        # all clients to use refresh tokens, if the homeserver administrator requires it.
+        nonrefreshable_access_token_lifetime = config.get(
+            "nonrefreshable_access_token_lifetime",
+            None,
         )
-        if access_token_lifetime is not None:
-            access_token_lifetime = self.parse_duration(access_token_lifetime)
-        self.access_token_lifetime = access_token_lifetime
-
-        if session_lifetime is not None and access_token_lifetime is not None:
-            raise ConfigError(
-                "The refresh token mechanism is incompatible with the "
-                "`session_lifetime` option. Consider disabling the "
-                "`session_lifetime` option or disabling the refresh token "
-                "mechanism by removing the `access_token_lifetime` option."
+        if nonrefreshable_access_token_lifetime is not None:
+            nonrefreshable_access_token_lifetime = self.parse_duration(
+                nonrefreshable_access_token_lifetime
             )
+        self.nonrefreshable_access_token_lifetime = nonrefreshable_access_token_lifetime
+
+        if (
+            self.session_lifetime is not None
+            and self.nonrefreshable_access_token_lifetime is not None
+        ):
+            if self.session_lifetime < self.nonrefreshable_access_token_lifetime:
+                raise ConfigError(
+                    "Both `session_lifetime` and `nonrefreshable_access_token_lifetime` "
+                    "configuration options have been set, but `nonrefreshable_access_token_lifetime` "
+                    "exceeds `session_lifetime`!"
+                )
+
+        refresh_token_lifetime = config.get("refresh_token_lifetime")
+        if refresh_token_lifetime is not None:
+            refresh_token_lifetime = self.parse_duration(refresh_token_lifetime)
+        self.refresh_token_lifetime: Optional[int] = refresh_token_lifetime
+
+        if (
+            self.session_lifetime is not None
+            and self.refresh_token_lifetime is not None
+        ):
+            if self.session_lifetime < self.refresh_token_lifetime:
+                raise ConfigError(
+                    "Both `session_lifetime` and `refresh_token_lifetime` "
+                    "configuration options have been set, but `refresh_token_lifetime` "
+                    "exceeds `session_lifetime`!"
+                )
 
         # The fallback template used for authenticating using a registration token
         self.registration_token_template = self.read_template("registration_token.html")
@@ -171,6 +220,44 @@ class RegistrationConfig(Config):
         #
         #session_lifetime: 24h
 
+        # Time that an access token remains valid for, if the session is
+        # using refresh tokens.
+        # For more information about refresh tokens, please see the manual.
+        # Note that this only applies to clients which advertise support for
+        # refresh tokens.
+        #
+        # Note also that this is calculated at login time and refresh time:
+        # changes are not applied to existing sessions until they are refreshed.
+        #
+        # By default, this is 5 minutes.
+        #
+        #refreshable_access_token_lifetime: 5m
+
+        # Time that a refresh token remains valid for (provided that it is not
+        # exchanged for another one first).
+        # This option can be used to automatically log out inactive sessions.
+        # Please see the manual for more information.
+        #
+        # Note also that this is calculated at login time and refresh time:
+        # changes are not applied to existing sessions until they are refreshed.
+        #
+        # By default, this is infinite.
+        #
+        #refresh_token_lifetime: 24h
+
+        # Time that an access token remains valid for, if the session is NOT
+        # using refresh tokens.
+        # Please note that not all clients support refresh tokens, so setting
+        # this to a short value may be inconvenient for some users who will
+        # then be logged out frequently.
+        #
+        # Note also that this is calculated at login time: changes are not applied
+        # retrospectively to existing sessions for users that have already logged in.
+        #
+        # By default, this is infinite.
+        #
+        #nonrefreshable_access_token_lifetime: 24h
+
         # The user must provide all of the below types of 3PID when registering.
         #
         #registrations_require_3pid:
@@ -364,7 +451,7 @@ class RegistrationConfig(Config):
         )
 
     @staticmethod
-    def add_arguments(parser):
+    def add_arguments(parser: argparse.ArgumentParser) -> None:
         reg_group = parser.add_argument_group("registration")
         reg_group.add_argument(
             "--enable-registration",
@@ -373,6 +460,6 @@ class RegistrationConfig(Config):
             help="Enable registration for new users.",
         )
 
-    def read_arguments(self, args):
+    def read_arguments(self, args: argparse.Namespace) -> None:
         if args.enable_registration is not None:
             self.enable_registration = strtobool(str(args.enable_registration))
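Taken together, the new options split token lifetimes three ways: `refreshable_access_token_lifetime` (default 5m) applies to access tokens issued to clients using refresh tokens, `nonrefreshable_access_token_lifetime` (default unlimited) applies to clients that do not, and `refresh_token_lifetime` (default unlimited) bounds the refresh token itself; each of the checks above rejects a value that exceeds `session_lifetime` when both are configured. A simplified standalone sketch of that validation pattern (hypothetical helper, durations already in milliseconds):

    from typing import Optional

    def check_against_session_lifetime(
        session_lifetime: Optional[int], lifetime: Optional[int], option_name: str
    ) -> None:
        # Mirrors the shape of the checks above: only compare when both are configured.
        if session_lifetime is not None and lifetime is not None:
            if session_lifetime < lifetime:
                raise ValueError("`%s` exceeds `session_lifetime`" % option_name)

    # A 24h session with a 5m refreshable access token lifetime passes:
    check_against_session_lifetime(
        24 * 60 * 60 * 1000, 5 * 60 * 1000, "refreshable_access_token_lifetime"
    )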
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 69906a98d4..b129b9dd68 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -15,11 +15,12 @@
 import logging
 import os
 from collections import namedtuple
-from typing import Dict, List
+from typing import Dict, List, Tuple
 from urllib.request import getproxies_environment  # type: ignore
 
 from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST, generate_ip_set
 from synapse.python_dependencies import DependencyException, check_requirements
+from synapse.types import JsonDict
 from synapse.util.module_loader import load_module
 
 from ._base import Config, ConfigError
@@ -57,7 +58,9 @@ MediaStorageProviderConfig = namedtuple(
 )
 
 
-def parse_thumbnail_requirements(thumbnail_sizes):
+def parse_thumbnail_requirements(
+    thumbnail_sizes: List[JsonDict],
+) -> Dict[str, Tuple[ThumbnailRequirement, ...]]:
     """Takes a list of dictionaries with "width", "height", and "method" keys
     and creates a map from image media types to the thumbnail size, thumbnailing
     method, and thumbnail media type to precalculate
@@ -69,7 +72,7 @@ def parse_thumbnail_requirements(thumbnail_sizes):
         Dictionary mapping from media type string to list of
         ThumbnailRequirement tuples.
     """
-    requirements: Dict[str, List] = {}
+    requirements: Dict[str, List[ThumbnailRequirement]] = {}
     for size in thumbnail_sizes:
         width = size["width"]
         height = size["height"]
diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py
index 56981cac79..3c5e0f7ce7 100644
--- a/synapse/config/room_directory.py
+++ b/synapse/config/room_directory.py
@@ -1,4 +1,5 @@
 # Copyright 2018 New Vector Ltd
+# Copyright 2021 Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,7 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.util import glob_to_regex
+from typing import List
+
+from matrix_common.regex import glob_to_regex
+
+from synapse.types import JsonDict
 
 from ._base import Config, ConfigError
 
@@ -20,7 +25,7 @@ from ._base import Config, ConfigError
 class RoomDirectoryConfig(Config):
     section = "roomdirectory"
 
-    def read_config(self, config, **kwargs):
+    def read_config(self, config, **kwargs) -> None:
         self.enable_room_list_search = config.get("enable_room_list_search", True)
 
         alias_creation_rules = config.get("alias_creation_rules")
@@ -47,7 +52,7 @@ class RoomDirectoryConfig(Config):
                 _RoomDirectoryRule("room_list_publication_rules", {"action": "allow"})
             ]
 
-    def generate_config_section(self, config_dir_path, server_name, **kwargs):
+    def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
         return """
         # Uncomment to disable searching the public room list. When disabled
         # blocks searching local and remote room lists for local and remote
@@ -113,16 +118,16 @@ class RoomDirectoryConfig(Config):
         #    action: allow
         """
 
-    def is_alias_creation_allowed(self, user_id, room_id, alias):
+    def is_alias_creation_allowed(self, user_id: str, room_id: str, alias: str) -> bool:
         """Checks if the given user is allowed to create the given alias
 
         Args:
-            user_id (str)
-            room_id (str)
-            alias (str)
+            user_id: The user to check.
+            room_id: The room ID for the alias.
+            alias: The alias being created.
 
         Returns:
-            boolean: True if user is allowed to create the alias
+            True if user is allowed to create the alias
         """
         for rule in self._alias_creation_rules:
             if rule.matches(user_id, room_id, [alias]):
@@ -130,16 +135,18 @@ class RoomDirectoryConfig(Config):
 
         return False
 
-    def is_publishing_room_allowed(self, user_id, room_id, aliases):
+    def is_publishing_room_allowed(
+        self, user_id: str, room_id: str, aliases: List[str]
+    ) -> bool:
         """Checks if the given user is allowed to publish the room
 
         Args:
-            user_id (str)
-            room_id (str)
-            aliases (list[str]): any local aliases associated with the room
+            user_id: The user ID publishing the room.
+            room_id: The room being published.
+            aliases: any local aliases associated with the room
 
         Returns:
-            boolean: True if user can publish room
+            True if user can publish room
         """
         for rule in self._room_list_publication_rules:
             if rule.matches(user_id, room_id, aliases):
@@ -153,11 +160,11 @@ class _RoomDirectoryRule:
     creating an alias or publishing a room.
     """
 
-    def __init__(self, option_name, rule):
+    def __init__(self, option_name: str, rule: JsonDict):
         """
         Args:
-            option_name (str): Name of the config option this rule belongs to
-            rule (dict): The rule as specified in the config
+            option_name: Name of the config option this rule belongs to
+            rule: The rule as specified in the config
         """
 
         action = rule["action"]
@@ -181,18 +188,18 @@ class _RoomDirectoryRule:
         except Exception as e:
             raise ConfigError("Failed to parse glob into regex") from e
 
-    def matches(self, user_id, room_id, aliases):
+    def matches(self, user_id: str, room_id: str, aliases: List[str]) -> bool:
         """Tests if this rule matches the given user_id, room_id and aliases.
 
         Args:
-            user_id (str)
-            room_id (str)
-            aliases (list[str]): The associated aliases to the room. Will be a
-                single element for testing alias creation, and can be empty for
-                testing room publishing.
+            user_id: The user ID to check.
+            room_id: The room ID to check.
+            aliases: The associated aliases to the room. Will be a single element
+                for testing alias creation, and can be empty for testing room
+                publishing.
 
         Returns:
-            boolean
+            True if the rule matches.
         """
 
         # Note: The regexes are anchored at both ends
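The only functional change to this file is that `glob_to_regex` now comes from the `matrix-common` package rather than `synapse.util`; the rule semantics are untouched. For illustration, a dependency-free sketch of the anchored glob matching these rules rely on (hand-rolled translation, not the library's implementation):

    import re

    def simple_glob_to_regex(glob: str) -> "re.Pattern[str]":
        # '*' matches any run of characters, '?' matches a single character;
        # the result is anchored at both ends, as noted above.
        pattern = "".join(
            ".*" if c == "*" else "." if c == "?" else re.escape(c) for c in glob
        )
        return re.compile("\\A" + pattern + "\\Z")

    assert simple_glob_to_regex("#*:example.com").match("#general:example.com")
    assert not simple_glob_to_regex("#*:example.com").match("#general:other.com")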
diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py
index ba2b0905ff..ec9d9f65e7 100644
--- a/synapse/config/saml2.py
+++ b/synapse/config/saml2.py
@@ -1,5 +1,5 @@
 # Copyright 2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,10 +14,11 @@
 # limitations under the License.
 
 import logging
-from typing import Any, List
+from typing import Any, List, Set
 
 from synapse.config.sso import SsoAttributeRequirement
 from synapse.python_dependencies import DependencyException, check_requirements
+from synapse.types import JsonDict
 from synapse.util.module_loader import load_module, load_python_module
 
 from ._base import Config, ConfigError
@@ -33,7 +34,7 @@ LEGACY_USER_MAPPING_PROVIDER = (
 )
 
 
-def _dict_merge(merge_dict, into_dict):
+def _dict_merge(merge_dict: dict, into_dict: dict) -> None:
     """Do a deep merge of two dicts
 
     Recursively merges `merge_dict` into `into_dict`:
@@ -43,8 +44,8 @@ def _dict_merge(merge_dict, into_dict):
         the value from `merge_dict`.
 
     Args:
-        merge_dict (dict): dict to merge
-        into_dict (dict): target dict
+        merge_dict: dict to merge
+        into_dict: target dict to be modified
     """
     for k, v in merge_dict.items():
         if k not in into_dict:
@@ -64,7 +65,7 @@ def _dict_merge(merge_dict, into_dict):
 class SAML2Config(Config):
     section = "saml2"
 
-    def read_config(self, config, **kwargs):
+    def read_config(self, config, **kwargs) -> None:
         self.saml2_enabled = False
 
         saml2_config = config.get("saml2_config")
@@ -183,8 +184,8 @@ class SAML2Config(Config):
         )
 
     def _default_saml_config_dict(
-        self, required_attributes: set, optional_attributes: set
-    ):
+        self, required_attributes: Set[str], optional_attributes: Set[str]
+    ) -> JsonDict:
         """Generate a configuration dictionary with required and optional attributes that
         will be needed to process new user registration
 
@@ -195,7 +196,7 @@ class SAML2Config(Config):
                 additional information to Synapse user accounts, but are not required
 
         Returns:
-            dict: A SAML configuration dictionary
+            A SAML configuration dictionary
         """
         import saml2
 
@@ -222,7 +223,7 @@ class SAML2Config(Config):
             },
         }
 
-    def generate_config_section(self, config_dir_path, server_name, **kwargs):
+    def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
         return """\
         ## Single sign-on integration ##
 
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 7bc0030a9e..ba5b954263 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import argparse
 import itertools
 import logging
 import os.path
@@ -27,6 +28,7 @@ from netaddr import AddrFormatError, IPNetwork, IPSet
 from twisted.conch.ssh.keys import Key
 
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.types import JsonDict
 from synapse.util.module_loader import load_module
 from synapse.util.stringutils import parse_and_validate_server_name
 
@@ -421,7 +423,7 @@ class ServerConfig(Config):
         # before redacting them.
         redaction_retention_period = config.get("redaction_retention_period", "7d")
         if redaction_retention_period is not None:
-            self.redaction_retention_period = self.parse_duration(
+            self.redaction_retention_period: Optional[int] = self.parse_duration(
                 redaction_retention_period
             )
         else:
@@ -430,7 +432,7 @@ class ServerConfig(Config):
         # How long to keep entries in the `users_ips` table.
         user_ips_max_age = config.get("user_ips_max_age", "28d")
         if user_ips_max_age is not None:
-            self.user_ips_max_age = self.parse_duration(user_ips_max_age)
+            self.user_ips_max_age: Optional[int] = self.parse_duration(user_ips_max_age)
         else:
             self.user_ips_max_age = None
 
@@ -1223,7 +1225,7 @@ class ServerConfig(Config):
             % locals()
         )
 
-    def read_arguments(self, args):
+    def read_arguments(self, args: argparse.Namespace) -> None:
         if args.manhole is not None:
             self.manhole = args.manhole
         if args.daemonize is not None:
@@ -1232,7 +1234,7 @@ class ServerConfig(Config):
             self.print_pidfile = args.print_pidfile
 
     @staticmethod
-    def add_arguments(parser):
+    def add_arguments(parser: argparse.ArgumentParser) -> None:
         server_group = parser.add_argument_group("server")
         server_group.add_argument(
             "-D",
@@ -1274,14 +1276,16 @@ class ServerConfig(Config):
             )
 
 
-def is_threepid_reserved(reserved_threepids, threepid):
+def is_threepid_reserved(
+    reserved_threepids: List[JsonDict], threepid: JsonDict
+) -> bool:
     """Check the threepid against the reserved threepid config
     Args:
-        reserved_threepids([dict]) - list of reserved threepids
-        threepid(dict) - The threepid to test for
+        reserved_threepids: List of reserved threepids
+        threepid: The threepid to test for
 
     Returns:
-        boolean Is the threepid undertest reserved_user
+        True if the threepid under test is reserved
     """
 
     for tp in reserved_threepids:
@@ -1290,7 +1294,9 @@ def is_threepid_reserved(reserved_threepids, threepid):
     return False
 
 
-def read_gc_thresholds(thresholds):
+def read_gc_thresholds(
+    thresholds: Optional[List[Any]],
+) -> Optional[Tuple[int, int, int]]:
     """Reads the three integer thresholds for garbage collection. Ensures that
     the thresholds are integers if thresholds are supplied.
     """
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index 60aacb13ea..e4a4243261 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -1,4 +1,4 @@
-# Copyright 2020 The Matrix.org Foundation C.I.C.
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -29,13 +29,13 @@ https://matrix-org.github.io/synapse/latest/templates.html
 ---------------------------------------------------------------------------------------"""
 
 
-@attr.s(frozen=True)
+@attr.s(frozen=True, auto_attribs=True)
 class SsoAttributeRequirement:
     """Object describing a single requirement for SSO attributes."""
 
-    attribute = attr.ib(type=str)
+    attribute: str
     # If a value is not given, then the attribute must simply exist.
-    value = attr.ib(type=Optional[str])
+    value: Optional[str]
 
     JSON_SCHEMA = {
         "type": "object",
@@ -49,7 +49,7 @@ class SSOConfig(Config):
 
     section = "sso"
 
-    def read_config(self, config, **kwargs):
+    def read_config(self, config, **kwargs) -> None:
         sso_config: Dict[str, Any] = config.get("sso") or {}
 
         # The sso-specific template_dir
@@ -106,7 +106,7 @@ class SSOConfig(Config):
         )
         self.sso_client_whitelist.append(login_fallback_url)
 
-    def generate_config_section(self, **kwargs):
+    def generate_config_section(self, **kwargs) -> str:
         return """\
         # Additional settings to use with single-sign on systems such as OpenID Connect,
         # SAML2 and CAS.
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 6227434bac..3e235b57a7 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -14,14 +14,14 @@
 
 import logging
 import os
-from datetime import datetime
 from typing import List, Optional, Pattern
 
+from matrix_common.regex import glob_to_regex
+
 from OpenSSL import SSL, crypto
 from twisted.internet._sslverify import Certificate, trustRootFromCertificates
 
 from synapse.config._base import Config, ConfigError
-from synapse.util import glob_to_regex
 
 logger = logging.getLogger(__name__)
 
@@ -133,55 +133,6 @@ class TlsConfig(Config):
         self.tls_certificate: Optional[crypto.X509] = None
         self.tls_private_key: Optional[crypto.PKey] = None
 
-    def is_disk_cert_valid(self, allow_self_signed=True):
-        """
-        Is the certificate we have on disk valid, and if so, for how long?
-
-        Args:
-            allow_self_signed (bool): Should we allow the certificate we
-                read to be self signed?
-
-        Returns:
-            int: Days remaining of certificate validity.
-            None: No certificate exists.
-        """
-        if not os.path.exists(self.tls_certificate_file):
-            return None
-
-        try:
-            with open(self.tls_certificate_file, "rb") as f:
-                cert_pem = f.read()
-        except Exception as e:
-            raise ConfigError(
-                "Failed to read existing certificate file %s: %s"
-                % (self.tls_certificate_file, e)
-            )
-
-        try:
-            tls_certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem)
-        except Exception as e:
-            raise ConfigError(
-                "Failed to parse existing certificate file %s: %s"
-                % (self.tls_certificate_file, e)
-            )
-
-        if not allow_self_signed:
-            if tls_certificate.get_subject() == tls_certificate.get_issuer():
-                raise ValueError(
-                    "TLS Certificate is self signed, and this is not permitted"
-                )
-
-        # YYYYMMDDhhmmssZ -- in UTC
-        expiry_data = tls_certificate.get_notAfter()
-        if expiry_data is None:
-            raise ValueError(
-                "TLS Certificate has no expiry date, and this is not permitted"
-            )
-        expires_on = datetime.strptime(expiry_data.decode("ascii"), "%Y%m%d%H%M%SZ")
-        now = datetime.utcnow()
-        days_remaining = (expires_on - now).days
-        return days_remaining
-
     def read_certificate_from_disk(self):
         """
         Read the certificates and private key from disk.
@@ -263,8 +214,8 @@ class TlsConfig(Config):
         #
         #federation_certificate_verification_whitelist:
         #  - lon.example.com
-        #  - *.domain.com
-        #  - *.onion
+        #  - "*.domain.com"
+        #  - "*.onion"
 
         # List of custom certificate authorities for federation traffic.
         #
@@ -295,7 +246,7 @@ class TlsConfig(Config):
         cert_path = self.tls_certificate_file
         logger.info("Loading TLS certificate from %s", cert_path)
         cert_pem = self.read_file(cert_path, "tls_certificate_path")
-        cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem)
+        cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem.encode())
 
         return cert
 
diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
index 2552f688d0..6d6678c7e4 100644
--- a/synapse/config/user_directory.py
+++ b/synapse/config/user_directory.py
@@ -53,8 +53,8 @@ class UserDirectoryConfig(Config):
             # indexes were (re)built was before Synapse 1.44, you'll have to
             # rebuild the indexes in order to search through all known users.
             # These indexes are built the first time Synapse starts; admins can
-            # manually trigger a rebuild following the instructions at
-            #     https://matrix-org.github.io/synapse/latest/user_directory.html
+            # manually trigger a rebuild via API following the instructions at
+            #     https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/background_updates.html#run
             #
             # Uncomment to return search results containing all known users, even if that
             # user does not share a room with the requester.
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index 4507992031..576f519188 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -1,4 +1,5 @@
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import argparse
 from typing import List, Union
 
 import attr
@@ -343,7 +345,7 @@ class WorkerConfig(Config):
         #worker_replication_secret: ""
         """
 
-    def read_arguments(self, args):
+    def read_arguments(self, args: argparse.Namespace) -> None:
         # We support a bunch of command line arguments that override options in
         # the config. A lot of these options have a worker_* prefix when running
         # on workers so we also have to override them when command line options
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index f641ab7ef5..993b04099e 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -1,5 +1,4 @@
-# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2017, 2018 New Vector Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -120,16 +119,6 @@ class VerifyJsonRequest:
             key_ids=key_ids,
         )
 
-    def to_fetch_key_request(self) -> "_FetchKeyRequest":
-        """Create a key fetch request for all keys needed to satisfy the
-        verification request.
-        """
-        return _FetchKeyRequest(
-            server_name=self.server_name,
-            minimum_valid_until_ts=self.minimum_valid_until_ts,
-            key_ids=self.key_ids,
-        )
-
 
 class KeyLookupError(ValueError):
     pass
@@ -179,8 +168,22 @@ class Keyring:
             clock=hs.get_clock(),
             process_batch_callback=self._inner_fetch_key_requests,
         )
-        self.verify_key = get_verify_key(hs.signing_key)
-        self.hostname = hs.hostname
+
+        self._hostname = hs.hostname
+
+        # build a FetchKeyResult for each of our own keys, to shortcircuit the
+        # fetcher.
+        self._local_verify_keys: Dict[str, FetchKeyResult] = {}
+        for key_id, key in hs.config.key.old_signing_keys.items():
+            self._local_verify_keys[key_id] = FetchKeyResult(
+                verify_key=key, valid_until_ts=key.expired_ts
+            )
+
+        vk = get_verify_key(hs.signing_key)
+        self._local_verify_keys[f"{vk.alg}:{vk.version}"] = FetchKeyResult(
+            verify_key=vk,
+            valid_until_ts=2 ** 63,  # fake future timestamp
+        )
 
     async def verify_json_for_server(
         self,
@@ -267,22 +270,32 @@ class Keyring:
                 Codes.UNAUTHORIZED,
             )
 
-        # If we are the originating server don't fetch verify key for self over federation
-        if verify_request.server_name == self.hostname:
-            await self._process_json(self.verify_key, verify_request)
-            return
+        found_keys: Dict[str, FetchKeyResult] = {}
 
-        # Add the keys we need to verify to the queue for retrieval. We queue
-        # up requests for the same server so we don't end up with many in flight
-        # requests for the same keys.
-        key_request = verify_request.to_fetch_key_request()
-        found_keys_by_server = await self._server_queue.add_to_queue(
-            key_request, key=verify_request.server_name
-        )
+        # If we are the originating server, short-circuit the key-fetch for any keys
+        # we already have
+        if verify_request.server_name == self._hostname:
+            for key_id in verify_request.key_ids:
+                if key_id in self._local_verify_keys:
+                    found_keys[key_id] = self._local_verify_keys[key_id]
+
+        key_ids_to_find = set(verify_request.key_ids) - found_keys.keys()
+        if key_ids_to_find:
+            # Add the keys we need to verify to the queue for retrieval. We queue
+            # up requests for the same server so we don't end up with many in flight
+            # requests for the same keys.
+            key_request = _FetchKeyRequest(
+                server_name=verify_request.server_name,
+                minimum_valid_until_ts=verify_request.minimum_valid_until_ts,
+                key_ids=list(key_ids_to_find),
+            )
+            found_keys_by_server = await self._server_queue.add_to_queue(
+                key_request, key=verify_request.server_name
+            )
 
-        # Since we batch up requests the returned set of keys may contain keys
-        # from other servers, so we pull out only the ones we care about.s
-        found_keys = found_keys_by_server.get(verify_request.server_name, {})
+            # Since we batch up requests the returned set of keys may contain keys
+            # from other servers, so we pull out only the ones we care about.
+            found_keys.update(found_keys_by_server.get(verify_request.server_name, {}))
 
         # Verify each signature we got valid keys for, raising if we can't
         # verify any of them.
@@ -654,21 +667,25 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
             perspective_name,
         )
 
+        request: JsonDict = {}
+        for queue_value in keys_to_fetch:
+            # there may be multiple requests for each server, so we have to merge
+            # them intelligently.
+            request_for_server = {
+                key_id: {
+                    "minimum_valid_until_ts": queue_value.minimum_valid_until_ts,
+                }
+                for key_id in queue_value.key_ids
+            }
+            request.setdefault(queue_value.server_name, {}).update(request_for_server)
+
+        logger.debug("Request to notary server %s: %s", perspective_name, request)
+
         try:
             query_response = await self.client.post_json(
                 destination=perspective_name,
                 path="/_matrix/key/v2/query",
-                data={
-                    "server_keys": {
-                        queue_value.server_name: {
-                            key_id: {
-                                "minimum_valid_until_ts": queue_value.minimum_valid_until_ts,
-                            }
-                            for key_id in queue_value.key_ids
-                        }
-                        for queue_value in keys_to_fetch
-                    }
-                },
+                data={"server_keys": request},
             )
         except (NotRetryingDestination, RequestSendFailed) as e:
             # these both have str() representations which we can't really improve upon
@@ -676,6 +693,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
         except HttpResponseException as e:
             raise KeyLookupError("Remote server returned an error: %s" % (e,))
 
+        logger.debug(
+            "Response from notary server %s: %s", perspective_name, query_response
+        )
+
         keys: Dict[str, Dict[str, FetchKeyResult]] = {}
         added_keys: List[Tuple[str, str, FetchKeyResult]] = []
 
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index d7527008c4..f251402ed8 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -322,6 +322,11 @@ class _AsyncEventContextImpl(EventContext):
         attributes by loading from the database.
         """
         if self.state_group is None:
+            # No state group means the event is an outlier. Usually the state_ids dicts are also
+            # pre-set to empty dicts, but they get reset when the context is serialized, so set
+            # them to empty dicts again here.
+            self._current_state_ids = {}
+            self._prev_state_ids = {}
             return
 
         current_state_ids = await self._storage.state.get_state_ids_for_group(
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 6fa631aa1d..84ef69df67 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -1,4 +1,5 @@
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -305,6 +306,7 @@ def format_event_for_client_v2_without_room_id(d: JsonDict) -> JsonDict:
 def serialize_event(
     e: Union[JsonDict, EventBase],
     time_now_ms: int,
+    *,
     as_client_event: bool = True,
     event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1,
     token_id: Optional[str] = None,
@@ -392,15 +394,18 @@ class EventClientSerializer:
         self,
         event: Union[JsonDict, EventBase],
         time_now: int,
+        *,
         bundle_aggregations: bool = True,
         **kwargs: Any,
     ) -> JsonDict:
         """Serializes a single event.
 
         Args:
-            event
+            event: The event being serialized.
             time_now: The current time in milliseconds
-            bundle_aggregations: Whether to bundle in related events
+            bundle_aggregations: Whether to include the bundled aggregations for this
+                event. Only applies to non-state events. (State events never include
+                bundled aggregations.)
             **kwargs: Arguments to pass to `serialize_event`
 
         Returns:
@@ -410,76 +415,109 @@ class EventClientSerializer:
         if not isinstance(event, EventBase):
             return event
 
-        event_id = event.event_id
         serialized_event = serialize_event(event, time_now, **kwargs)
 
-        # If MSC1849 is enabled then we need to look if there are any relations
-        # we need to bundle in with the event.
-        # Do not bundle relations if the event has been redacted
-        if not event.internal_metadata.is_redacted() and (
-            self._msc1849_enabled and bundle_aggregations
+        # Check if there are any bundled aggregations to include with the event.
+        #
+        # Do not bundle aggregations if any of the following are true:
+        #
+        # * Support is disabled via the configuration or the caller.
+        # * The event is a state event.
+        # * The event has been redacted.
+        if (
+            self._msc1849_enabled
+            and bundle_aggregations
+            and not event.is_state()
+            and not event.internal_metadata.is_redacted()
         ):
-            annotations = await self.store.get_aggregation_groups_for_event(event_id)
-            references = await self.store.get_relations_for_event(
-                event_id, RelationTypes.REFERENCE, direction="f"
-            )
+            await self._injected_bundled_aggregations(event, time_now, serialized_event)
 
-            if annotations.chunk:
-                r = serialized_event["unsigned"].setdefault("m.relations", {})
-                r[RelationTypes.ANNOTATION] = annotations.to_dict()
-
-            if references.chunk:
-                r = serialized_event["unsigned"].setdefault("m.relations", {})
-                r[RelationTypes.REFERENCE] = references.to_dict()
-
-            edit = None
-            if event.type == EventTypes.Message:
-                edit = await self.store.get_applicable_edit(event_id)
-
-            if edit:
-                # If there is an edit replace the content, preserving existing
-                # relations.
-
-                # Ensure we take copies of the edit content, otherwise we risk modifying
-                # the original event.
-                edit_content = edit.content.copy()
-
-                # Unfreeze the event content if necessary, so that we may modify it below
-                edit_content = unfreeze(edit_content)
-                serialized_event["content"] = edit_content.get("m.new_content", {})
-
-                # Check for existing relations
-                relations = event.content.get("m.relates_to")
-                if relations:
-                    # Keep the relations, ensuring we use a dict copy of the original
-                    serialized_event["content"]["m.relates_to"] = relations.copy()
-                else:
-                    serialized_event["content"].pop("m.relates_to", None)
-
-                r = serialized_event["unsigned"].setdefault("m.relations", {})
-                r[RelationTypes.REPLACE] = {
-                    "event_id": edit.event_id,
-                    "origin_server_ts": edit.origin_server_ts,
-                    "sender": edit.sender,
-                }
+        return serialized_event
 
-            # If this event is the start of a thread, include a summary of the replies.
-            if self._msc3440_enabled:
-                (
-                    thread_count,
-                    latest_thread_event,
-                ) = await self.store.get_thread_summary(event_id)
-                if latest_thread_event:
-                    r = serialized_event["unsigned"].setdefault("m.relations", {})
-                    r[RelationTypes.THREAD] = {
-                        # Don't bundle aggregations as this could recurse forever.
-                        "latest_event": await self.serialize_event(
-                            latest_thread_event, time_now, bundle_aggregations=False
-                        ),
-                        "count": thread_count,
-                    }
+    async def _injected_bundled_aggregations(
+        self, event: EventBase, time_now: int, serialized_event: JsonDict
+    ) -> None:
+        """Potentially injects bundled aggregations into the unsigned portion of the serialized event.
 
-        return serialized_event
+        Args:
+            event: The event being serialized.
+            time_now: The current time in milliseconds
+            serialized_event: The serialized event which may be modified.
+
+        """
+        # Do not bundle aggregations for an event which represents an edit or an
+        # annotation. It does not make sense for them to have related events.
+        relates_to = event.content.get("m.relates_to")
+        if isinstance(relates_to, (dict, frozendict)):
+            relation_type = relates_to.get("rel_type")
+            if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
+                return
+
+        event_id = event.event_id
+
+        # The bundled aggregations to include.
+        aggregations = {}
+
+        annotations = await self.store.get_aggregation_groups_for_event(event_id)
+        if annotations.chunk:
+            aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()
+
+        references = await self.store.get_relations_for_event(
+            event_id, RelationTypes.REFERENCE, direction="f"
+        )
+        if references.chunk:
+            aggregations[RelationTypes.REFERENCE] = references.to_dict()
+
+        edit = None
+        if event.type == EventTypes.Message:
+            edit = await self.store.get_applicable_edit(event_id)
+
+        if edit:
+            # If there is an edit replace the content, preserving existing
+            # relations.
+
+            # Ensure we take copies of the edit content, otherwise we risk modifying
+            # the original event.
+            edit_content = edit.content.copy()
+
+            # Unfreeze the event content if necessary, so that we may modify it below
+            edit_content = unfreeze(edit_content)
+            serialized_event["content"] = edit_content.get("m.new_content", {})
+
+            # Check for existing relations
+            relates_to = event.content.get("m.relates_to")
+            if relates_to:
+                # Keep the relations, ensuring we use a dict copy of the original
+                serialized_event["content"]["m.relates_to"] = relates_to.copy()
+            else:
+                serialized_event["content"].pop("m.relates_to", None)
+
+            aggregations[RelationTypes.REPLACE] = {
+                "event_id": edit.event_id,
+                "origin_server_ts": edit.origin_server_ts,
+                "sender": edit.sender,
+            }
+
+        # If this event is the start of a thread, include a summary of the replies.
+        if self._msc3440_enabled:
+            (
+                thread_count,
+                latest_thread_event,
+            ) = await self.store.get_thread_summary(event_id)
+            if latest_thread_event:
+                aggregations[RelationTypes.THREAD] = {
+                    # Don't bundle aggregations as this could recurse forever.
+                    "latest_event": await self.serialize_event(
+                        latest_thread_event, time_now, bundle_aggregations=False
+                    ),
+                    "count": thread_count,
+                }
+
+        # If any bundled aggregations were found, include them.
+        if aggregations:
+            serialized_event["unsigned"].setdefault("m.relations", {}).update(
+                aggregations
+            )
 
     async def serialize_events(
         self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any
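For orientation only (not part of the patch): a minimal sketch of the shape that the new helper leaves under `unsigned["m.relations"]`, assuming one annotation group and one edit were found for the event. The relation-type keys correspond to `RelationTypes` values; the chunk contents and identifiers below are made-up examples.

# Illustrative only: rough shape of a serialized event after bundled
# aggregations have been injected. Values are hypothetical.
serialized_event = {
    "event_id": "$original:example.com",
    "type": "m.room.message",
    "content": {"msgtype": "m.text", "body": "edited body"},
    "unsigned": {
        "m.relations": {
            "m.annotation": {
                "chunk": [{"type": "m.reaction", "key": "+1", "count": 2}],
            },
            "m.replace": {
                "event_id": "$edit:example.com",
                "origin_server_ts": 1638547064321,
                "sender": "@alice:example.com",
            },
        },
    },
}
print(sorted(serialized_event["unsigned"]["m.relations"]))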
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 3b85b135e0..fee1477ab6 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -128,7 +128,7 @@ class FederationClient(FederationBase):
             reset_expiry_on_get=False,
         )
 
-    def _clear_tried_cache(self):
+    def _clear_tried_cache(self) -> None:
         """Clear pdu_destination_tried cache"""
         now = self._clock.time_msec()
 
@@ -800,7 +800,7 @@ class FederationClient(FederationBase):
                 no servers successfully handle the request.
         """
 
-        async def send_request(destination) -> SendJoinResult:
+        async def send_request(destination: str) -> SendJoinResult:
             response = await self._do_send_join(room_version, destination, pdu)
 
             # If an event was returned (and expected to be returned):
@@ -1395,11 +1395,28 @@ class FederationClient(FederationBase):
         async def send_request(
             destination: str,
         ) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]:
-            res = await self.transport_layer.get_room_hierarchy(
-                destination=destination,
-                room_id=room_id,
-                suggested_only=suggested_only,
-            )
+            try:
+                res = await self.transport_layer.get_room_hierarchy(
+                    destination=destination,
+                    room_id=room_id,
+                    suggested_only=suggested_only,
+                )
+            except HttpResponseException as e:
+                # If an error is received that is due to an unrecognised endpoint,
+                # fallback to the unstable endpoint. Otherwise consider it a
+                # legitimate error and raise.
+                if not self._is_unknown_endpoint(e):
+                    raise
+
+                logger.debug(
+                    "Couldn't fetch room hierarchy with the v1 API, falling back to the unstable API"
+                )
+
+                res = await self.transport_layer.get_room_hierarchy_unstable(
+                    destination=destination,
+                    room_id=room_id,
+                    suggested_only=suggested_only,
+                )
 
             room = res.get("room")
             if not isinstance(room, dict):
@@ -1449,6 +1466,10 @@ class FederationClient(FederationBase):
             if e.code != 502:
                 raise
 
+            logger.debug(
+                "Couldn't fetch room hierarchy, falling back to the spaces API"
+            )
+
             # Fallback to the old federation API and translate the results if
             # no servers implement the new API.
             #
@@ -1496,6 +1517,83 @@ class FederationClient(FederationBase):
         self._get_room_hierarchy_cache[(room_id, suggested_only)] = result
         return result
 
+    async def timestamp_to_event(
+        self, destination: str, room_id: str, timestamp: int, direction: str
+    ) -> "TimestampToEventResponse":
+        """
+        Calls a remote federating server at `destination` asking for their
+        closest event to the given timestamp in the given direction. Also
+        validates that the response contains the expected keys, raising an
+        error if it does not.
+
+        Args:
+            destination: Domain name of the remote homeserver
+            room_id: Room to fetch the event from
+            timestamp: The point in time (inclusive) we should navigate from in
+                the given direction to find the closest event.
+            direction: ["f"|"b"] to indicate whether we should navigate forward
+                or backward from the given timestamp to find the closest event.
+
+        Returns:
+            A parsed TimestampToEventResponse including the closest event_id
+            and origin_server_ts
+
+        Raises:
+            Various exceptions when the request fails
+            InvalidResponseError when the response does not have the correct
+            keys or the values are of the wrong type
+        """
+        remote_response = await self.transport_layer.timestamp_to_event(
+            destination, room_id, timestamp, direction
+        )
+
+        if not isinstance(remote_response, dict):
+            raise InvalidResponseError(
+                "Response must be a JSON dictionary but received %r" % remote_response
+            )
+
+        try:
+            return TimestampToEventResponse.from_json_dict(remote_response)
+        except ValueError as e:
+            raise InvalidResponseError(str(e))
+
+
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class TimestampToEventResponse:
+    """Typed response dictionary for the federation /timestamp_to_event endpoint"""
+
+    event_id: str
+    origin_server_ts: int
+
+    # the raw data, including the above keys
+    data: JsonDict
+
+    @classmethod
+    def from_json_dict(cls, d: JsonDict) -> "TimestampToEventResponse":
+        """Parsed response from the federation /timestamp_to_event endpoint
+
+        Args:
+            d: JSON object response to be parsed
+
+        Raises:
+            ValueError if d does not have the correct keys or they are the wrong types
+        """
+
+        event_id = d.get("event_id")
+        if not isinstance(event_id, str):
+            raise ValueError(
+                "Invalid response: 'event_id' must be a str but received %r" % event_id
+            )
+
+        origin_server_ts = d.get("origin_server_ts")
+        if not isinstance(origin_server_ts, int):
+            raise ValueError(
+                "Invalid response: 'origin_server_ts' must be a int but received %r"
+                % origin_server_ts
+            )
+
+        return cls(event_id, origin_server_ts, d)
+
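A rough usage sketch of the new parsing path, assuming the patched `synapse.federation.federation_client` module is importable; the response dict is hand-written here, whereas a real one comes back from `transport_layer.timestamp_to_event`.

from synapse.federation.federation_client import TimestampToEventResponse

# Hypothetical remote response body.
remote_response = {"event_id": "$abc123:example.com", "origin_server_ts": 1638547064321}

parsed = TimestampToEventResponse.from_json_dict(remote_response)
print(parsed.event_id, parsed.origin_server_ts)

# A malformed response raises ValueError, which the federation client
# converts into an InvalidResponseError.
try:
    TimestampToEventResponse.from_json_dict({"event_id": 12345})
except ValueError as e:
    print(e)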
 
 @attr.s(frozen=True, slots=True, auto_attribs=True)
 class FederationSpaceSummaryEventResult:
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 9a8758e9a6..4697a62c18 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -1,6 +1,6 @@
 # Copyright 2015, 2016 OpenMarket Ltd
 # Copyright 2018 New Vector Ltd
-# Copyright 2019 Matrix.org Federation C.I.C
+# Copyright 2019-2021 Matrix.org Federation C.I.C
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -28,6 +28,7 @@ from typing import (
     Union,
 )
 
+from matrix_common.regex import glob_to_regex
 from prometheus_client import Counter, Gauge, Histogram
 
 from twisted.internet import defer
@@ -66,7 +67,7 @@ from synapse.replication.http.federation import (
 )
 from synapse.storage.databases.main.lock import Lock
 from synapse.types import JsonDict, get_domain_from_id
-from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
+from synapse.util import json_decoder, unwrapFirstError
 from synapse.util.async_helpers import Linearizer, concurrently_execute
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.util.stringutils import parse_server_name
@@ -110,6 +111,7 @@ class FederationServer(FederationBase):
         super().__init__(hs)
 
         self.handler = hs.get_federation_handler()
+        self.storage = hs.get_storage()
         self._federation_event_handler = hs.get_federation_event_handler()
         self.state = hs.get_state_handler()
         self._event_auth_handler = hs.get_event_auth_handler()
@@ -200,6 +202,48 @@ class FederationServer(FederationBase):
 
         return 200, res
 
+    async def on_timestamp_to_event_request(
+        self, origin: str, room_id: str, timestamp: int, direction: str
+    ) -> Tuple[int, Dict[str, Any]]:
+        """When we receive a federated `/timestamp_to_event` request,
+        handle all of the logic for validating and fetching the event.
+
+        Args:
+            origin: The server we received the event from
+            room_id: Room to fetch the event from
+            timestamp: The point in time (inclusive) we should navigate from in
+                the given direction to find the closest event.
+            direction: ["f"|"b"] to indicate whether we should navigate forward
+                or backward from the given timestamp to find the closest event.
+
+        Returns:
+            Tuple indicating the response status code and dictionary response
+            body including `event_id`.
+        """
+        with (await self._server_linearizer.queue((origin, room_id))):
+            origin_host, _ = parse_server_name(origin)
+            await self.check_server_matches_acl(origin_host, room_id)
+
+            # We only try to fetch data from the local database
+            event_id = await self.store.get_event_id_for_timestamp(
+                room_id, timestamp, direction
+            )
+            if event_id:
+                event = await self.store.get_event(
+                    event_id, allow_none=False, allow_rejected=False
+                )
+
+                return 200, {
+                    "event_id": event_id,
+                    "origin_server_ts": event.origin_server_ts,
+                }
+
+        raise SynapseError(
+            404,
+            "Unable to find event from %s in direction %s" % (timestamp, direction),
+            errcode=Codes.NOT_FOUND,
+        )
+
     async def on_incoming_transaction(
         self,
         origin: str,
@@ -407,7 +451,7 @@ class FederationServer(FederationBase):
         # require callouts to other servers to fetch missing events), but
         # impose a limit to avoid going too crazy with ram/cpu.
 
-        async def process_pdus_for_room(room_id: str):
+        async def process_pdus_for_room(room_id: str) -> None:
             with nested_logging_context(room_id):
                 logger.debug("Processing PDUs for %s", room_id)
 
@@ -504,7 +548,7 @@ class FederationServer(FederationBase):
 
     async def on_state_ids_request(
         self, origin: str, room_id: str, event_id: str
-    ) -> Tuple[int, Dict[str, Any]]:
+    ) -> Tuple[int, JsonDict]:
         if not event_id:
             raise NotImplementedError("Specify an event")
 
@@ -524,7 +568,9 @@ class FederationServer(FederationBase):
 
         return 200, resp
 
-    async def _on_state_ids_request_compute(self, room_id, event_id):
+    async def _on_state_ids_request_compute(
+        self, room_id: str, event_id: str
+    ) -> JsonDict:
         state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
         auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids)
         return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
@@ -613,8 +659,11 @@ class FederationServer(FederationBase):
         state = await self.store.get_events(state_ids)
 
         time_now = self._clock.time_msec()
+        event_json = event.get_pdu_json()
         return {
-            "org.matrix.msc3083.v2.event": event.get_pdu_json(),
+            # TODO Remove the unstable prefix when servers have updated.
+            "org.matrix.msc3083.v2.event": event_json,
+            "event": event_json,
             "state": [p.get_pdu_json(time_now) for p in state.values()],
             "auth_chain": [p.get_pdu_json(time_now) for p in auth_chain],
         }
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index 4fead6ca29..523ab1c51e 100644
--- a/synapse/federation/persistence.py
+++ b/synapse/federation/persistence.py
@@ -1,4 +1,5 @@
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -23,6 +24,7 @@ from typing import Optional, Tuple
 
 from synapse.federation.units import Transaction
 from synapse.logging.utils import log_function
+from synapse.storage.databases.main import DataStore
 from synapse.types import JsonDict
 
 logger = logging.getLogger(__name__)
@@ -31,7 +33,7 @@ logger = logging.getLogger(__name__)
 class TransactionActions:
     """Defines persistence actions that relate to handling Transactions."""
 
-    def __init__(self, datastore):
+    def __init__(self, datastore: DataStore):
         self.store = datastore
 
     @log_function
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 1fbf325fdc..63289a5a33 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -1,4 +1,5 @@
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -350,7 +351,7 @@ class BaseFederationRow:
     TypeId = ""  # Unique string that ids the type. Must be overridden in sub classes.
 
     @staticmethod
-    def from_data(data):
+    def from_data(data: JsonDict) -> "BaseFederationRow":
         """Parse the data from the federation stream into a row.
 
         Args:
@@ -359,7 +360,7 @@ class BaseFederationRow:
         """
         raise NotImplementedError()
 
-    def to_data(self):
+    def to_data(self) -> JsonDict:
         """Serialize this row to be sent over the federation stream.
 
         Returns:
@@ -368,7 +369,7 @@ class BaseFederationRow:
         """
         raise NotImplementedError()
 
-    def add_to_buffer(self, buff):
+    def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None:
         """Add this row to the appropriate field in the buffer ready for this
         to be sent over federation.
 
@@ -391,15 +392,15 @@ class PresenceDestinationsRow(
     TypeId = "pd"
 
     @staticmethod
-    def from_data(data):
+    def from_data(data: JsonDict) -> "PresenceDestinationsRow":
         return PresenceDestinationsRow(
             state=UserPresenceState.from_dict(data["state"]), destinations=data["dests"]
         )
 
-    def to_data(self):
+    def to_data(self) -> JsonDict:
         return {"state": self.state.as_dict(), "dests": self.destinations}
 
-    def add_to_buffer(self, buff):
+    def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None:
         buff.presence_destinations.append((self.state, self.destinations))
 
 
@@ -417,13 +418,13 @@ class KeyedEduRow(
     TypeId = "k"
 
     @staticmethod
-    def from_data(data):
+    def from_data(data: JsonDict) -> "KeyedEduRow":
         return KeyedEduRow(key=tuple(data["key"]), edu=Edu(**data["edu"]))
 
-    def to_data(self):
+    def to_data(self) -> JsonDict:
         return {"key": self.key, "edu": self.edu.get_internal_dict()}
 
-    def add_to_buffer(self, buff):
+    def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None:
         buff.keyed_edus.setdefault(self.edu.destination, {})[self.key] = self.edu
 
 
@@ -433,13 +434,13 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))):  # Edu
     TypeId = "e"
 
     @staticmethod
-    def from_data(data):
+    def from_data(data: JsonDict) -> "EduRow":
         return EduRow(Edu(**data))
 
-    def to_data(self):
+    def to_data(self) -> JsonDict:
         return self.edu.get_internal_dict()
 
-    def add_to_buffer(self, buff):
+    def add_to_buffer(self, buff: "ParsedFederationStreamData") -> None:
         buff.edus.setdefault(self.edu.destination, []).append(self.edu)
 
 
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index afe35e72b6..391b30fbb5 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -1,5 +1,6 @@
 # Copyright 2014-2016 OpenMarket Ltd
 # Copyright 2019 New Vector Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,7 +15,8 @@
 # limitations under the License.
 import datetime
 import logging
-from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple
+from types import TracebackType
+from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, Type
 
 import attr
 from prometheus_client import Counter
@@ -213,7 +215,7 @@ class PerDestinationQueue:
         self._pending_edus_keyed[(edu.edu_type, key)] = edu
         self.attempt_new_transaction()
 
-    def send_edu(self, edu) -> None:
+    def send_edu(self, edu: Edu) -> None:
         self._pending_edus.append(edu)
         self.attempt_new_transaction()
 
@@ -701,7 +703,12 @@ class _TransactionQueueManager:
 
         return self._pdus, pending_edus
 
-    async def __aexit__(self, exc_type, exc, tb):
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc: Optional[BaseException],
+        tb: Optional[TracebackType],
+    ) -> None:
         if exc_type is not None:
             # Failed to send transaction, so we bail out.
             return
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 10b5aa5af8..9fc4c31c93 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -21,6 +21,7 @@ from typing import (
     Callable,
     Collection,
     Dict,
+    Generator,
     Iterable,
     List,
     Mapping,
@@ -149,6 +150,42 @@ class TransportLayerClient:
         )
 
     @log_function
+    async def timestamp_to_event(
+        self, destination: str, room_id: str, timestamp: int, direction: str
+    ) -> Union[JsonDict, List]:
+        """
+        Calls a remote federating server at `destination` asking for their
+        closest event to the given timestamp in the given direction.
+
+        Args:
+            destination: Domain name of the remote homeserver
+            room_id: Room to fetch the event from
+            timestamp: The point in time (inclusive) we should navigate from in
+                the given direction to find the closest event.
+            direction: ["f"|"b"] to indicate whether we should navigate forward
+                or backward from the given timestamp to find the closest event.
+
+        Returns:
+            Response dict received from the remote homeserver.
+
+        Raises:
+            Various exceptions when the request fails
+        """
+        path = _create_path(
+            FEDERATION_UNSTABLE_PREFIX,
+            "/org.matrix.msc3030/timestamp_to_event/%s",
+            room_id,
+        )
+
+        args = {"ts": [str(timestamp)], "dir": [direction]}
+
+        remote_response = await self.client.get_json(
+            destination, path=path, args=args, try_trailing_slash_on_400=True
+        )
+
+        return remote_response
+
+    @log_function
     async def send_transaction(
         self,
         transaction: Transaction,
@@ -199,11 +236,16 @@ class TransportLayerClient:
 
     @log_function
     async def make_query(
-        self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False
-    ):
+        self,
+        destination: str,
+        query_type: str,
+        args: dict,
+        retry_on_dns_fail: bool,
+        ignore_backoff: bool = False,
+    ) -> JsonDict:
         path = _create_v1_path("/query/%s", query_type)
 
-        content = await self.client.get_json(
+        return await self.client.get_json(
             destination=destination,
             path=path,
             args=args,
@@ -212,8 +254,6 @@ class TransportLayerClient:
             ignore_backoff=ignore_backoff,
         )
 
-        return content
-
     @log_function
     async def make_membership_event(
         self,
@@ -1192,10 +1232,24 @@ class TransportLayerClient:
         )
 
     async def get_room_hierarchy(
-        self,
-        destination: str,
-        room_id: str,
-        suggested_only: bool,
+        self, destination: str, room_id: str, suggested_only: bool
+    ) -> JsonDict:
+        """
+        Args:
+            destination: The remote server
+            room_id: The room ID to ask about.
+            suggested_only: if True, only suggested rooms will be returned
+        """
+        path = _create_v1_path("/hierarchy/%s", room_id)
+
+        return await self.client.get_json(
+            destination=destination,
+            path=path,
+            args={"suggested_only": "true" if suggested_only else "false"},
+        )
+
+    async def get_room_hierarchy_unstable(
+        self, destination: str, room_id: str, suggested_only: bool
     ) -> JsonDict:
         """
         Args:
@@ -1267,7 +1321,7 @@ class SendJoinResponse:
 
 
 @ijson.coroutine
-def _event_parser(event_dict: JsonDict):
+def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]:
     """Helper function for use with `ijson.kvitems_coro` to parse key-value pairs
     to add them to a given dictionary.
     """
@@ -1278,7 +1332,9 @@ def _event_parser(event_dict: JsonDict):
 
 
 @ijson.coroutine
-def _event_list_parser(room_version: RoomVersion, events: List[EventBase]):
+def _event_list_parser(
+    room_version: RoomVersion, events: List[EventBase]
+) -> Generator[None, JsonDict, None]:
     """Helper function for use with `ijson.items_coro` to parse an array of
     events and add them to the given list.
     """
@@ -1317,15 +1373,26 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
             prefix + "auth_chain.item",
             use_float=True,
         )
-        self._coro_event = ijson.kvitems_coro(
+        # TODO Remove the unstable prefix when servers have updated.
+        #
+        # By re-using the same event dictionary this will cause the parsing of
+        # org.matrix.msc3083.v2.event and event to stomp over each other.
+        # Generally this should be fine.
+        self._coro_unstable_event = ijson.kvitems_coro(
             _event_parser(self._response.event_dict),
             prefix + "org.matrix.msc3083.v2.event",
             use_float=True,
         )
+        self._coro_event = ijson.kvitems_coro(
+            _event_parser(self._response.event_dict),
+            prefix + "event",
+            use_float=True,
+        )
 
     def write(self, data: bytes) -> int:
         self._coro_state.send(data)
         self._coro_auth.send(data)
+        self._coro_unstable_event.send(data)
         self._coro_event.send(data)
 
         return len(data)
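A standalone sketch (the JSON body is made up) of the dual-prefix parsing described in the comment above: two `kvitems` coroutines feed the same dictionary, so the stable `event` key and the unstable `org.matrix.msc3083.v2.event` key overwrite each other harmlessly when both are present.

import ijson

@ijson.coroutine
def _collect(target):
    # Receives (key, value) pairs and writes them into the shared dict.
    while True:
        key, value = yield
        target[key] = value

event_dict = {}
coro_unstable = ijson.kvitems_coro(_collect(event_dict), "org.matrix.msc3083.v2.event")
coro_stable = ijson.kvitems_coro(_collect(event_dict), "event")

body = b'{"event": {"type": "m.room.member"}, "org.matrix.msc3083.v2.event": {"type": "m.room.member"}}'
coro_unstable.send(body)
coro_stable.send(body)
print(event_dict)  # both prefixes parse into the same dict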
diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py
index c32539bf5a..77b936361a 100644
--- a/synapse/federation/transport/server/__init__.py
+++ b/synapse/federation/transport/server/__init__.py
@@ -22,7 +22,10 @@ from synapse.federation.transport.server._base import (
     Authenticator,
     BaseFederationServlet,
 )
-from synapse.federation.transport.server.federation import FEDERATION_SERVLET_CLASSES
+from synapse.federation.transport.server.federation import (
+    FEDERATION_SERVLET_CLASSES,
+    FederationTimestampLookupServlet,
+)
 from synapse.federation.transport.server.groups_local import GROUP_LOCAL_SERVLET_CLASSES
 from synapse.federation.transport.server.groups_server import (
     GROUP_SERVER_SERVLET_CLASSES,
@@ -299,7 +302,7 @@ def register_servlets(
     authenticator: Authenticator,
     ratelimiter: FederationRateLimiter,
     servlet_groups: Optional[Iterable[str]] = None,
-):
+) -> None:
     """Initialize and register servlet classes.
 
     Will by default register all servlets. For custom behaviour, pass in
@@ -324,6 +327,13 @@ def register_servlets(
             )
 
         for servletclass in DEFAULT_SERVLET_GROUPS[servlet_group]:
+            # Only allow the `/timestamp_to_event` servlet if msc3030 is enabled
+            if (
+                servletclass == FederationTimestampLookupServlet
+                and not hs.config.experimental.msc3030_enabled
+            ):
+                continue
+
             servletclass(
                 hs=hs,
                 authenticator=authenticator,
diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py
index cef65929c5..dc39e3537b 100644
--- a/synapse/federation/transport/server/_base.py
+++ b/synapse/federation/transport/server/_base.py
@@ -15,10 +15,13 @@
 import functools
 import logging
 import re
+from typing import Any, Awaitable, Callable, Optional, Tuple, cast
 
 from synapse.api.errors import Codes, FederationDeniedError, SynapseError
 from synapse.api.urls import FEDERATION_V1_PREFIX
+from synapse.http.server import HttpServer, ServletCallback
 from synapse.http.servlet import parse_json_object_from_request
+from synapse.http.site import SynapseRequest
 from synapse.logging import opentracing
 from synapse.logging.context import run_in_background
 from synapse.logging.opentracing import (
@@ -29,6 +32,7 @@ from synapse.logging.opentracing import (
     whitelisted_homeserver,
 )
 from synapse.server import HomeServer
+from synapse.types import JsonDict
 from synapse.util.ratelimitutils import FederationRateLimiter
 from synapse.util.stringutils import parse_and_validate_server_name
 
@@ -59,9 +63,11 @@ class Authenticator:
             self.replication_client = hs.get_tcp_replication()
 
     # A method just so we can pass 'self' as the authenticator to the Servlets
-    async def authenticate_request(self, request, content):
+    async def authenticate_request(
+        self, request: SynapseRequest, content: Optional[JsonDict]
+    ) -> str:
         now = self._clock.time_msec()
-        json_request = {
+        json_request: JsonDict = {
             "method": request.method.decode("ascii"),
             "uri": request.uri.decode("ascii"),
             "destination": self.server_name,
@@ -114,7 +120,7 @@ class Authenticator:
 
         return origin
 
-    async def _reset_retry_timings(self, origin):
+    async def _reset_retry_timings(self, origin: str) -> None:
         try:
             logger.info("Marking origin %r as up", origin)
             await self.store.set_destination_retry_timings(origin, None, 0, 0)
@@ -133,14 +139,14 @@ class Authenticator:
             logger.exception("Error resetting retry timings on %s", origin)
 
 
-def _parse_auth_header(header_bytes):
+def _parse_auth_header(header_bytes: bytes) -> Tuple[str, str, str]:
     """Parse an X-Matrix auth header
 
     Args:
-        header_bytes (bytes): header value
+        header_bytes: header value
 
     Returns:
-        Tuple[str, str, str]: origin, key id, signature.
+        origin, key id, signature.
 
     Raises:
         AuthenticationError if the header could not be parsed
@@ -148,9 +154,9 @@ def _parse_auth_header(header_bytes):
     try:
         header_str = header_bytes.decode("utf-8")
         params = header_str.split(" ")[1].split(",")
-        param_dict = dict(kv.split("=") for kv in params)
+        param_dict = {k: v for k, v in (kv.split("=", maxsplit=1) for kv in params)}
 
-        def strip_quotes(value):
+        def strip_quotes(value: str) -> str:
             if value.startswith('"'):
                 return value[1:-1]
             else:
@@ -233,23 +239,25 @@ class BaseFederationServlet:
         self.ratelimiter = ratelimiter
         self.server_name = server_name
 
-    def _wrap(self, func):
+    def _wrap(self, func: Callable[..., Awaitable[Tuple[int, Any]]]) -> ServletCallback:
         authenticator = self.authenticator
         ratelimiter = self.ratelimiter
 
         @functools.wraps(func)
-        async def new_func(request, *args, **kwargs):
+        async def new_func(
+            request: SynapseRequest, *args: Any, **kwargs: str
+        ) -> Optional[Tuple[int, Any]]:
             """A callback which can be passed to HttpServer.RegisterPaths
 
             Args:
-                request (twisted.web.http.Request):
+                request:
                 *args: unused?
-                **kwargs (dict[unicode, unicode]): the dict mapping keys to path
-                    components as specified in the path match regexp.
+                **kwargs: the dict mapping keys to path components as specified
+                    in the path match regexp.
 
             Returns:
-                Tuple[int, object]|None: (response code, response object) as returned by
-                    the callback method. None if the request has already been handled.
+                (response code, response object) as returned by the callback method.
+                None if the request has already been handled.
             """
             content = None
             if request.method in [b"PUT", b"POST"]:
@@ -257,7 +265,9 @@ class BaseFederationServlet:
                 content = parse_json_object_from_request(request)
 
             try:
-                origin = await authenticator.authenticate_request(request, content)
+                origin: Optional[str] = await authenticator.authenticate_request(
+                    request, content
+                )
             except NoAuthenticationError:
                 origin = None
                 if self.REQUIRE_AUTH:
@@ -301,7 +311,7 @@ class BaseFederationServlet:
                                 "client disconnected before we started processing "
                                 "request"
                             )
-                            return -1, None
+                            return None
                         response = await func(
                             origin, content, request.args, *args, **kwargs
                         )
@@ -312,9 +322,9 @@ class BaseFederationServlet:
 
             return response
 
-        return new_func
+        return cast(ServletCallback, new_func)
 
-    def register(self, server):
+    def register(self, server: HttpServer) -> None:
         pattern = re.compile("^" + self.PREFIX + self.PATH + "$")
 
         for method in ("GET", "PUT", "POST"):
diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
index 2fdf6cc99e..77bfd88ad0 100644
--- a/synapse/federation/transport/server/federation.py
+++ b/synapse/federation/transport/server/federation.py
@@ -174,6 +174,46 @@ class FederationBackfillServlet(BaseFederationServerServlet):
         return await self.handler.on_backfill_request(origin, room_id, versions, limit)
 
 
+class FederationTimestampLookupServlet(BaseFederationServerServlet):
+    """
+    API endpoint to fetch the `event_id` of the closest event to the given
+    timestamp (`ts` query parameter) in the given direction (`dir` query
+    parameter).
+
+    Useful for other homeservers when they're unable to find an event locally.
+
+    `ts` is a timestamp in milliseconds where we will find the closest event in
+    the given direction.
+
+    `dir` can be `f` or `b` to indicate forwards and backwards in time from the
+    given timestamp.
+
+    GET /_matrix/federation/unstable/org.matrix.msc3030/timestamp_to_event/<roomID>?ts=<timestamp>&dir=<direction>
+    {
+        "event_id": ...
+    }
+    """
+
+    PATH = "/timestamp_to_event/(?P<room_id>[^/]*)/?"
+    PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc3030"
+
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
+        timestamp = parse_integer_from_args(query, "ts", required=True)
+        direction = parse_string_from_args(
+            query, "dir", default="f", allowed_values=["f", "b"], required=True
+        )
+
+        return await self.handler.on_timestamp_to_event_request(
+            origin, room_id, timestamp, direction
+        )
+
+
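A small sketch (hypothetical room ID and timestamp) of the request path this servlet matches, mirroring the GET example in the docstring above:

from urllib.parse import quote, urlencode

room_id = "!somewhere:example.com"
path = "/_matrix/federation/unstable/org.matrix.msc3030/timestamp_to_event/%s?%s" % (
    quote(room_id, safe=""),
    urlencode({"ts": 1638547064321, "dir": "f"}),
)
print(path)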
 class FederationQueryServlet(BaseFederationServerServlet):
     PATH = "/query/(?P<query_type>[^/]*)"
 
@@ -611,7 +651,6 @@ class FederationSpaceSummaryServlet(BaseFederationServlet):
 
 
 class FederationRoomHierarchyServlet(BaseFederationServlet):
-    PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946"
     PATH = "/hierarchy/(?P<room_id>[^/]*)"
 
     def __init__(
@@ -637,6 +676,10 @@ class FederationRoomHierarchyServlet(BaseFederationServlet):
         )
 
 
+class FederationRoomHierarchyUnstableServlet(FederationRoomHierarchyServlet):
+    PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946"
+
+
 class RoomComplexityServlet(BaseFederationServlet):
     """
     Indicates to other servers how complex (and therefore likely
@@ -680,6 +723,7 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
     FederationStateV1Servlet,
     FederationStateIdsServlet,
     FederationBackfillServlet,
+    FederationTimestampLookupServlet,
     FederationQueryServlet,
     FederationMakeJoinServlet,
     FederationMakeLeaveServlet,
@@ -701,6 +745,7 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
     RoomComplexityServlet,
     FederationSpaceSummaryServlet,
     FederationRoomHierarchyServlet,
+    FederationRoomHierarchyUnstableServlet,
     FederationV1SendKnockServlet,
     FederationMakeKnockServlet,
 )
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index 53f99031b1..a87896e538 100644
--- a/synapse/groups/attestations.py
+++ b/synapse/groups/attestations.py
@@ -40,6 +40,8 @@ from typing import TYPE_CHECKING, Optional, Tuple
 
 from signedjson.sign import sign_json
 
+from twisted.internet.defer import Deferred
+
 from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import JsonDict, get_domain_from_id
@@ -166,7 +168,7 @@ class GroupAttestionRenewer:
 
         return {}
 
-    def _start_renew_attestations(self) -> None:
+    def _start_renew_attestations(self) -> "Deferred[None]":
         return run_as_background_process("renew_attestations", self._renew_attestations)
 
     async def _renew_attestations(self) -> None:
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 60e59d11a0..61607cf2ba 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -18,6 +18,7 @@ import time
 import unicodedata
 import urllib.parse
 from binascii import crc32
+from http import HTTPStatus
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -38,6 +39,7 @@ import attr
 import bcrypt
 import pymacaroons
 import unpaddedbase64
+from pymacaroons.exceptions import MacaroonVerificationFailedException
 
 from twisted.web.server import Request
 
@@ -181,8 +183,11 @@ class LoginTokenAttributes:
 
     user_id = attr.ib(type=str)
 
-    # the SSO Identity Provider that the user authenticated with, to get this token
     auth_provider_id = attr.ib(type=str)
+    """The SSO Identity Provider that the user authenticated with, to get this token."""
+
+    auth_provider_session_id = attr.ib(type=Optional[str])
+    """The session ID advertised by the SSO Identity Provider."""
 
 
 class AuthHandler:
@@ -756,53 +761,109 @@ class AuthHandler:
     async def refresh_token(
         self,
         refresh_token: str,
-        valid_until_ms: Optional[int],
-    ) -> Tuple[str, str]:
+        access_token_valid_until_ms: Optional[int],
+        refresh_token_valid_until_ms: Optional[int],
+    ) -> Tuple[str, str, Optional[int]]:
         """
         Consumes a refresh token and generate both a new access token and a new refresh token from it.
 
         The consumed refresh token is considered invalid after the first use of the new access token or the new refresh token.
 
+        The lifetime of both the access token and refresh token will be capped so that they
+        do not exceed the session's ultimate expiry time, if applicable.
+
         Args:
             refresh_token: The token to consume.
-            valid_until_ms: The expiration timestamp of the new access token.
-
+            access_token_valid_until_ms: The expiration timestamp of the new access token.
+                None if the access token does not expire.
+            refresh_token_valid_until_ms: The expiration timestamp of the new refresh token.
+                None if the refresh token does not expire.
         Returns:
-            A tuple containing the new access token and refresh token
+            A tuple containing:
+              - the new access token
+              - the new refresh token
+              - the actual expiry time of the access token, which may be earlier than
+                `access_token_valid_until_ms`.
         """
 
         # Verify the token signature first before looking up the token
         if not self._verify_refresh_token(refresh_token):
-            raise SynapseError(401, "invalid refresh token", Codes.UNKNOWN_TOKEN)
+            raise SynapseError(
+                HTTPStatus.UNAUTHORIZED, "invalid refresh token", Codes.UNKNOWN_TOKEN
+            )
 
         existing_token = await self.store.lookup_refresh_token(refresh_token)
         if existing_token is None:
-            raise SynapseError(401, "refresh token does not exist", Codes.UNKNOWN_TOKEN)
+            raise SynapseError(
+                HTTPStatus.UNAUTHORIZED,
+                "refresh token does not exist",
+                Codes.UNKNOWN_TOKEN,
+            )
 
         if (
             existing_token.has_next_access_token_been_used
             or existing_token.has_next_refresh_token_been_refreshed
         ):
             raise SynapseError(
-                403, "refresh token isn't valid anymore", Codes.FORBIDDEN
+                HTTPStatus.FORBIDDEN,
+                "refresh token isn't valid anymore",
+                Codes.FORBIDDEN,
             )
 
+        now_ms = self._clock.time_msec()
+
+        if existing_token.expiry_ts is not None and existing_token.expiry_ts < now_ms:
+
+            raise SynapseError(
+                HTTPStatus.FORBIDDEN,
+                "The supplied refresh token has expired",
+                Codes.FORBIDDEN,
+            )
+
+        if existing_token.ultimate_session_expiry_ts is not None:
+            # This session has a bounded lifetime, even across refreshes.
+
+            if access_token_valid_until_ms is not None:
+                access_token_valid_until_ms = min(
+                    access_token_valid_until_ms,
+                    existing_token.ultimate_session_expiry_ts,
+                )
+            else:
+                access_token_valid_until_ms = existing_token.ultimate_session_expiry_ts
+
+            if refresh_token_valid_until_ms is not None:
+                refresh_token_valid_until_ms = min(
+                    refresh_token_valid_until_ms,
+                    existing_token.ultimate_session_expiry_ts,
+                )
+            else:
+                refresh_token_valid_until_ms = existing_token.ultimate_session_expiry_ts
+            if existing_token.ultimate_session_expiry_ts < now_ms:
+                raise SynapseError(
+                    HTTPStatus.FORBIDDEN,
+                    "The session has expired and can no longer be refreshed",
+                    Codes.FORBIDDEN,
+                )
+
         (
             new_refresh_token,
             new_refresh_token_id,
-        ) = await self.get_refresh_token_for_user_id(
-            user_id=existing_token.user_id, device_id=existing_token.device_id
+        ) = await self.create_refresh_token_for_user_id(
+            user_id=existing_token.user_id,
+            device_id=existing_token.device_id,
+            expiry_ts=refresh_token_valid_until_ms,
+            ultimate_session_expiry_ts=existing_token.ultimate_session_expiry_ts,
         )
-        access_token = await self.get_access_token_for_user_id(
+        access_token = await self.create_access_token_for_user_id(
             user_id=existing_token.user_id,
             device_id=existing_token.device_id,
-            valid_until_ms=valid_until_ms,
+            valid_until_ms=access_token_valid_until_ms,
             refresh_token_id=new_refresh_token_id,
         )
         await self.store.replace_refresh_token(
             existing_token.token_id, new_refresh_token_id
         )
-        return access_token, new_refresh_token
+        return access_token, new_refresh_token, access_token_valid_until_ms
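A minimal sketch (hypothetical timestamps) of the capping rule applied above: neither the new access token nor the new refresh token may outlive the session's ultimate expiry, when one is set.

from typing import Optional

def cap_to_session_expiry(
    valid_until_ms: Optional[int], ultimate_session_expiry_ts: Optional[int]
) -> Optional[int]:
    # No session-wide bound: keep the requested expiry (possibly None).
    if ultimate_session_expiry_ts is None:
        return valid_until_ms
    # A session-wide bound exists: never exceed it.
    if valid_until_ms is None:
        return ultimate_session_expiry_ts
    return min(valid_until_ms, ultimate_session_expiry_ts)

print(cap_to_session_expiry(1_700_000_000_000, 1_650_000_000_000))  # capped to the session expiry
print(cap_to_session_expiry(None, 1_650_000_000_000))  # bounded even if no expiry was requested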
 
     def _verify_refresh_token(self, token: str) -> bool:
         """
@@ -832,10 +893,12 @@ class AuthHandler:
 
         return True
 
-    async def get_refresh_token_for_user_id(
+    async def create_refresh_token_for_user_id(
         self,
         user_id: str,
         device_id: str,
+        expiry_ts: Optional[int],
+        ultimate_session_expiry_ts: Optional[int],
     ) -> Tuple[str, int]:
         """
         Creates a new refresh token for the user with the given user ID.
@@ -843,6 +906,13 @@ class AuthHandler:
         Args:
             user_id: canonical user ID
             device_id: the device ID to associate with the token.
+            expiry_ts (milliseconds since the epoch): Time after which the
+                refresh token cannot be used.
+                If None, the refresh token only becomes invalid once it has been used.
+            ultimate_session_expiry_ts (milliseconds since the epoch):
+                Time at which the session will end and can not be extended any
+                further.
+                If None, the session can be refreshed indefinitely.
 
         Returns:
             The newly created refresh token and its ID in the database
@@ -852,10 +922,12 @@ class AuthHandler:
             user_id=user_id,
             token=refresh_token,
             device_id=device_id,
+            expiry_ts=expiry_ts,
+            ultimate_session_expiry_ts=ultimate_session_expiry_ts,
         )
         return refresh_token, refresh_token_id
 
-    async def get_access_token_for_user_id(
+    async def create_access_token_for_user_id(
         self,
         user_id: str,
         device_id: Optional[str],
@@ -1582,6 +1654,7 @@ class AuthHandler:
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
         new_user: bool = False,
+        auth_provider_session_id: Optional[str] = None,
     ) -> None:
         """Having figured out a mxid for this user, complete the HTTP request
 
@@ -1597,6 +1670,7 @@ class AuthHandler:
                 during successful login. Must be JSON serializable.
             new_user: True if we should use wording appropriate to a user who has just
                 registered.
+            auth_provider_session_id: The session ID from the SSO IdP received during login.
         """
         # If the account has been deactivated, do not proceed with the login
         # flow.
@@ -1617,6 +1691,7 @@ class AuthHandler:
             extra_attributes,
             new_user=new_user,
             user_profile_data=profile,
+            auth_provider_session_id=auth_provider_session_id,
         )
 
     def _complete_sso_login(
@@ -1628,6 +1703,7 @@ class AuthHandler:
         extra_attributes: Optional[JsonDict] = None,
         new_user: bool = False,
         user_profile_data: Optional[ProfileInfo] = None,
+        auth_provider_session_id: Optional[str] = None,
     ) -> None:
         """
         The synchronous portion of complete_sso_login.
@@ -1649,7 +1725,9 @@ class AuthHandler:
 
         # Create a login token
         login_token = self.macaroon_gen.generate_short_term_login_token(
-            registered_user_id, auth_provider_id=auth_provider_id
+            registered_user_id,
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
         )
 
         # Append the login token to the original redirect URL (i.e. with its query
@@ -1754,6 +1832,7 @@ class MacaroonGenerator:
         self,
         user_id: str,
         auth_provider_id: str,
+        auth_provider_session_id: Optional[str] = None,
         duration_in_ms: int = (2 * 60 * 1000),
     ) -> str:
         macaroon = self._generate_base_macaroon(user_id)
@@ -1762,6 +1841,10 @@ class MacaroonGenerator:
         expiry = now + duration_in_ms
         macaroon.add_first_party_caveat("time < %d" % (expiry,))
         macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
+        if auth_provider_session_id is not None:
+            macaroon.add_first_party_caveat(
+                "auth_provider_session_id = %s" % (auth_provider_session_id,)
+            )
         return macaroon.serialize()
 
     def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
@@ -1783,15 +1866,28 @@ class MacaroonGenerator:
         user_id = get_value_from_macaroon(macaroon, "user_id")
         auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
 
+        auth_provider_session_id: Optional[str] = None
+        try:
+            auth_provider_session_id = get_value_from_macaroon(
+                macaroon, "auth_provider_session_id"
+            )
+        except MacaroonVerificationFailedException:
+            pass
+
         v = pymacaroons.Verifier()
         v.satisfy_exact("gen = 1")
         v.satisfy_exact("type = login")
         v.satisfy_general(lambda c: c.startswith("user_id = "))
         v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
+        v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = "))
         satisfy_expiry(v, self.hs.get_clock().time_msec)
         v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
 
-        return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id)
+        return LoginTokenAttributes(
+            user_id=user_id,
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
+        )
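A self-contained sketch (the secret key, user ID, and session ID are placeholders, and the expiry caveat is omitted) of how the optional `auth_provider_session_id` caveat interacts with verification: the extra `satisfy_general` rule added above makes the caveat acceptable whether or not it is present.

import pymacaroons

def make_login_macaroon(session_id=None):
    m = pymacaroons.Macaroon(location="example.com", identifier="key", key="not-a-real-secret")
    m.add_first_party_caveat("gen = 1")
    m.add_first_party_caveat("type = login")
    m.add_first_party_caveat("user_id = @alice:example.com")
    m.add_first_party_caveat("auth_provider_id = oidc")
    if session_id is not None:
        m.add_first_party_caveat("auth_provider_session_id = %s" % (session_id,))
    return m

v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
v.satisfy_exact("type = login")
v.satisfy_general(lambda c: c.startswith("user_id = "))
v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = "))

# Verification succeeds with or without the optional caveat.
print(v.verify(make_login_macaroon("session-xyz"), "not-a-real-secret"))
print(v.verify(make_login_macaroon(), "not-a-real-secret"))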
 
     def generate_delete_pusher_token(self, user_id: str) -> str:
         macaroon = self._generate_base_macaroon(user_id)
@@ -1828,13 +1924,6 @@ def load_single_legacy_password_auth_provider(
         logger.error("Error while initializing %r: %s", module, e)
         raise
 
-    # The known hooks. If a module implements a method who's name appears in this set
-    # we'll want to register it
-    password_auth_provider_methods = {
-        "check_3pid_auth",
-        "on_logged_out",
-    }
-
     # All methods that the module provides should be async, but this wasn't enforced
     # in the old module system, so we wrap them if needed
     def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
@@ -1919,11 +2008,14 @@ def load_single_legacy_password_auth_provider(
 
         return run
 
-    # populate hooks with the implemented methods, wrapped with async_wrapper
-    hooks = {
-        hook: async_wrapper(getattr(provider, hook, None))
-        for hook in password_auth_provider_methods
-    }
+    # If the module has these methods implemented, then we pull them out
+    # and register them as hooks.
+    check_3pid_auth_hook: Optional[CHECK_3PID_AUTH_CALLBACK] = async_wrapper(
+        getattr(provider, "check_3pid_auth", None)
+    )
+    on_logged_out_hook: Optional[ON_LOGGED_OUT_CALLBACK] = async_wrapper(
+        getattr(provider, "on_logged_out", None)
+    )
 
     supported_login_types = {}
     # call get_supported_login_types and add that to the dict
@@ -1950,7 +2042,11 @@ def load_single_legacy_password_auth_provider(
         # need to use a tuple here for ("password",) not a list since lists aren't hashable
         auth_checkers[(LoginType.PASSWORD, ("password",))] = check_password
 
-    api.register_password_auth_provider_callbacks(hooks, auth_checkers=auth_checkers)
+    api.register_password_auth_provider_callbacks(
+        check_3pid_auth=check_3pid_auth_hook,
+        on_logged_out=on_logged_out_hook,
+        auth_checkers=auth_checkers,
+    )
 
 
 CHECK_3PID_AUTH_CALLBACK = Callable[
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 68b446eb66..82ee11e921 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -301,6 +301,8 @@ class DeviceHandler(DeviceWorkerHandler):
         user_id: str,
         device_id: Optional[str],
         initial_device_display_name: Optional[str] = None,
+        auth_provider_id: Optional[str] = None,
+        auth_provider_session_id: Optional[str] = None,
     ) -> str:
         """
         If the given device has not been registered, register it with the
@@ -312,6 +314,8 @@ class DeviceHandler(DeviceWorkerHandler):
             user_id:  @user:id
             device_id: device id supplied by client
             initial_device_display_name: device display name from client
+            auth_provider_id: The SSO IdP the user used, if any.
+            auth_provider_session_id: The session ID (sid) got from the SSO IdP.
         Returns:
             device id (generated if none was supplied)
         """
@@ -323,6 +327,8 @@ class DeviceHandler(DeviceWorkerHandler):
                 user_id=user_id,
                 device_id=device_id,
                 initial_device_display_name=initial_device_display_name,
+                auth_provider_id=auth_provider_id,
+                auth_provider_session_id=auth_provider_session_id,
             )
             if new_device:
                 await self.notify_device_update(user_id, [device_id])
@@ -337,6 +343,8 @@ class DeviceHandler(DeviceWorkerHandler):
                 user_id=user_id,
                 device_id=new_device_id,
                 initial_device_display_name=initial_device_display_name,
+                auth_provider_id=auth_provider_id,
+                auth_provider_session_id=auth_provider_session_id,
             )
             if new_device:
                 await self.notify_device_update(user_id, [new_device_id])
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 1f64534a8a..32b0254c5f 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -122,8 +122,7 @@ class EventStreamHandler:
                 events,
                 time_now,
                 as_client_event=as_client_event,
-                # We don't bundle "live" events, as otherwise clients
-                # will end up double counting annotations.
+                # Don't bundle aggregations as this is a deprecated API.
                 bundle_aggregations=False,
             )
 
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 3112cc88b1..1ea837d082 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -68,6 +68,37 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
+    """Get joined domains from state
+
+    Args:
+        state: State map from type/state key to event.
+
+    Returns:
+        A list of servers with the lowest depth of their joins, sorted by
+        lowest depth first.
+    """
+    joined_users = [
+        (state_key, int(event.depth))
+        for (e_type, state_key), event in state.items()
+        if e_type == EventTypes.Member and event.membership == Membership.JOIN
+    ]
+
+    joined_domains: Dict[str, int] = {}
+    for u, d in joined_users:
+        try:
+            dom = get_domain_from_id(u)
+            old_d = joined_domains.get(dom)
+            if old_d:
+                joined_domains[dom] = min(d, old_d)
+            else:
+                joined_domains[dom] = d
+        except Exception:
+            pass
+
+    return sorted(joined_domains.items(), key=lambda d: d[1])
+
+
 class FederationHandler:
     """Handles general incoming federation requests
 
@@ -268,36 +299,6 @@ class FederationHandler:
 
         curr_state = await self.state_handler.get_current_state(room_id)
 
-        def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
-            """Get joined domains from state
-
-            Args:
-                state: State map from type/state key to event.
-
-            Returns:
-                Returns a list of servers with the lowest depth of their joins.
-                 Sorted by lowest depth first.
-            """
-            joined_users = [
-                (state_key, int(event.depth))
-                for (e_type, state_key), event in state.items()
-                if e_type == EventTypes.Member and event.membership == Membership.JOIN
-            ]
-
-            joined_domains: Dict[str, int] = {}
-            for u, d in joined_users:
-                try:
-                    dom = get_domain_from_id(u)
-                    old_d = joined_domains.get(dom)
-                    if old_d:
-                        joined_domains[dom] = min(d, old_d)
-                    else:
-                        joined_domains[dom] = d
-                except Exception:
-                    pass
-
-            return sorted(joined_domains.items(), key=lambda d: d[1])
-
         curr_domains = get_domains_from_state(curr_state)
 
         likely_domains = [
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 3dbe611f95..c83eaea359 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -464,15 +464,6 @@ class IdentityHandler:
         if next_link:
             params["next_link"] = next_link
 
-        if self.hs.config.email.using_identity_server_from_trusted_list:
-            # Warn that a deprecated config option is in use
-            logger.warning(
-                'The config option "trust_identity_server_for_password_resets" '
-                'has been replaced by "account_threepid_delegate". '
-                "Please consult the sample config at docs/sample_config.yaml for "
-                "details and update your config file."
-            )
-
         try:
             data = await self.http_client.post_json_get_json(
                 id_server + "/_matrix/identity/api/v1/validate/email/requestToken",
@@ -517,15 +508,6 @@ class IdentityHandler:
         if next_link:
             params["next_link"] = next_link
 
-        if self.hs.config.email.using_identity_server_from_trusted_list:
-            # Warn that a deprecated config option is in use
-            logger.warning(
-                'The config option "trust_identity_server_for_password_resets" '
-                'has been replaced by "account_threepid_delegate". '
-                "Please consult the sample config at docs/sample_config.yaml for "
-                "details and update your config file."
-            )
-
         try:
             data = await self.http_client.post_json_get_json(
                 id_server + "/_matrix/identity/api/v1/validate/msisdn/requestToken",
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index d4e4556155..9cd21e7f2b 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -165,7 +165,11 @@ class InitialSyncHandler:
 
                 invite_event = await self.store.get_event(event.event_id)
                 d["invite"] = await self._event_serializer.serialize_event(
-                    invite_event, time_now, as_client_event
+                    invite_event,
+                    time_now,
+                    # Don't bundle aggregations as this is a deprecated API.
+                    bundle_aggregations=False,
+                    as_client_event=as_client_event,
                 )
 
             rooms_ret.append(d)
@@ -216,7 +220,11 @@ class InitialSyncHandler:
                 d["messages"] = {
                     "chunk": (
                         await self._event_serializer.serialize_events(
-                            messages, time_now=time_now, as_client_event=as_client_event
+                            messages,
+                            time_now=time_now,
+                            # Don't bundle aggregations as this is a deprecated API.
+                            bundle_aggregations=False,
+                            as_client_event=as_client_event,
                         )
                     ),
                     "start": await start_token.to_string(self.store),
@@ -226,6 +234,8 @@ class InitialSyncHandler:
                 d["state"] = await self._event_serializer.serialize_events(
                     current_state.values(),
                     time_now=time_now,
+                    # Don't bundle aggregations as this is a deprecated API.
+                    bundle_aggregations=False,
                     as_client_event=as_client_event,
                 )
 
@@ -366,14 +376,18 @@ class InitialSyncHandler:
             "room_id": room_id,
             "messages": {
                 "chunk": (
-                    await self._event_serializer.serialize_events(messages, time_now)
+                    # Don't bundle aggregations as this is a deprecated API.
+                    await self._event_serializer.serialize_events(
+                        messages, time_now, bundle_aggregations=False
+                    )
                 ),
                 "start": await start_token.to_string(self.store),
                 "end": await end_token.to_string(self.store),
             },
             "state": (
+                # Don't bundle aggregations as this is a deprecated API.
                 await self._event_serializer.serialize_events(
-                    room_state.values(), time_now
+                    room_state.values(), time_now, bundle_aggregations=False
                 )
             ),
             "presence": [],
@@ -392,8 +406,9 @@ class InitialSyncHandler:
 
         # TODO: These concurrently
         time_now = self.clock.time_msec()
+        # Don't bundle aggregations as this is a deprecated API.
         state = await self._event_serializer.serialize_events(
-            current_state.values(), time_now
+            current_state.values(), time_now, bundle_aggregations=False
         )
 
         now_token = self.hs.get_event_sources().get_current_token()
@@ -467,7 +482,10 @@ class InitialSyncHandler:
             "room_id": room_id,
             "messages": {
                 "chunk": (
-                    await self._event_serializer.serialize_events(messages, time_now)
+                    # Don't bundle aggregations as this is a deprecated API.
+                    await self._event_serializer.serialize_events(
+                        messages, time_now, bundle_aggregations=False
+                    )
                 ),
                 "start": await start_token.to_string(self.store),
                 "end": await end_token.to_string(self.store),
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index d4c2a6ab7a..87f671708c 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -247,13 +247,7 @@ class MessageHandler:
                 room_state = room_state_events[membership_event_id]
 
         now = self.clock.time_msec()
-        events = await self._event_serializer.serialize_events(
-            room_state.values(),
-            now,
-            # We don't bother bundling aggregations in when asked for state
-            # events, as clients won't use them.
-            bundle_aggregations=False,
-        )
+        events = await self._event_serializer.serialize_events(room_state.values(), now)
         return events
 
     async def get_joined_members(self, requester: Requester, room_id: str) -> dict:
@@ -1001,13 +995,52 @@ class EventCreationHandler:
             )
 
         self.validator.validate_new(event, self.config)
+        await self._validate_event_relation(event)
+        logger.debug("Created event %s", event.event_id)
+
+        return event, context
+
+    async def _validate_event_relation(self, event: EventBase) -> None:
+        """
+        Ensure the relation data on a new event is not bogus.
+
+        Args:
+            event: The event being created.
+
+        Raises:
+            SynapseError if the event is invalid.
+        """
+
+        relation = event.content.get("m.relates_to")
+        if not relation:
+            return
+
+        relation_type = relation.get("rel_type")
+        if not relation_type:
+            return
+
+        # Ensure the parent is real.
+        relates_to = relation.get("event_id")
+        if not relates_to:
+            return
+
+        parent_event = await self.store.get_event(relates_to, allow_none=True)
+        if parent_event:
+            # And in the same room.
+            if parent_event.room_id != event.room_id:
+                raise SynapseError(400, "Relations must be in the same room")
+
+        else:
+            # The parent event isn't known locally, but the client must have some
+            # reason to believe it exists; check whether it is already the target
+            # of other relations and, if so, assume everything is fine.
+            if not await self.store.event_is_target_of_relation(relates_to):
+                # Otherwise, the client can't know about the parent event!
+                raise SynapseError(400, "Can't send relation to unknown event")
 
         # If this event is an annotation then we check that the sender
         # can't annotate the same way twice (e.g. stops users from liking an
         # event multiple times).
-        relation = event.content.get("m.relates_to", {})
-        if relation.get("rel_type") == RelationTypes.ANNOTATION:
-            relates_to = relation["event_id"]
+        if relation_type == RelationTypes.ANNOTATION:
             aggregation_key = relation["key"]
 
             already_exists = await self.store.has_user_annotated_event(
@@ -1016,9 +1049,12 @@ class EventCreationHandler:
             if already_exists:
                 raise SynapseError(400, "Can't send same reaction twice")
 
-        logger.debug("Created event %s", event.event_id)
-
-        return event, context
+        # Don't attempt to start a thread if the parent event is a relation.
+        elif relation_type == RelationTypes.THREAD:
+            if await self.store.event_includes_relation(relates_to):
+                raise SynapseError(
+                    400, "Cannot start threads from an event with a relation"
+                )
 
     @measure_func("handle_new_client_event")
     async def handle_new_client_event(
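
The new `_validate_event_relation` helper concentrates the sanity checks on `m.relates_to` in one place. A standalone sketch of the same decision sequence, with the storage lookups replaced by hypothetical async callables and the duplicate-annotation check omitted for brevity:

    from typing import Awaitable, Callable, Optional

    class RelationError(Exception):
        pass

    async def validate_relation(
        content: dict,
        room_id: str,
        # Hypothetical stand-ins for the store methods used in the diff above.
        get_event_room: Callable[[str], Awaitable[Optional[str]]],
        is_relation_target: Callable[[str], Awaitable[bool]],
        has_relation: Callable[[str], Awaitable[bool]],
    ) -> None:
        relation = content.get("m.relates_to")
        if not relation:
            return
        rel_type = relation.get("rel_type")
        parent_id = relation.get("event_id")
        if not rel_type or not parent_id:
            return

        parent_room = await get_event_room(parent_id)
        if parent_room is not None:
            if parent_room != room_id:
                raise RelationError("Relations must be in the same room")
        elif not await is_relation_target(parent_id):
            # We know nothing about the parent event at all.
            raise RelationError("Can't send relation to unknown event")

        # Threads may not fork off an event that is itself a relation.
        # ("m.thread" is illustrative; Synapse uses RelationTypes.THREAD.)
        if rel_type == "m.thread" and await has_relation(parent_id):
            raise RelationError("Cannot start threads from an event with a relation")
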
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index 3665d91513..deb3539751 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -23,7 +23,7 @@ from authlib.common.security import generate_token
 from authlib.jose import JsonWebToken, jwt
 from authlib.oauth2.auth import ClientAuth
 from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
-from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
+from authlib.oidc.core import CodeIDToken, UserInfo
 from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
 from jinja2 import Environment, Template
 from pymacaroons.exceptions import (
@@ -117,7 +117,8 @@ class OidcHandler:
         for idp_id, p in self._providers.items():
             try:
                 await p.load_metadata()
-                await p.load_jwks()
+                if not p._uses_userinfo:
+                    await p.load_jwks()
             except Exception as e:
                 raise Exception(
                     "Error while initialising OIDC provider %r" % (idp_id,)
@@ -498,10 +499,6 @@ class OidcProvider:
         return await self._jwks.get()
 
     async def _load_jwks(self) -> JWKS:
-        if self._uses_userinfo:
-            # We're not using jwt signing, return an empty jwk set
-            return {"keys": []}
-
         metadata = await self.load_metadata()
 
         # Load the JWKS using the `jwks_uri` metadata.
@@ -663,7 +660,7 @@ class OidcProvider:
 
         return UserInfo(resp)
 
-    async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
+    async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken:
         """Return an instance of UserInfo from token's ``id_token``.
 
         Args:
@@ -673,7 +670,7 @@ class OidcProvider:
                 request. This value should match the one inside the token.
 
         Returns:
-            An object representing the user.
+            The decoded claims in the ID token.
         """
         metadata = await self.load_metadata()
         claims_params = {
@@ -684,9 +681,6 @@ class OidcProvider:
             # If we got an `access_token`, there should be an `at_hash` claim
             # in the `id_token` that we can check against.
             claims_params["access_token"] = token["access_token"]
-            claims_cls = CodeIDToken
-        else:
-            claims_cls = ImplicitIDToken
 
         alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
         jwt = JsonWebToken(alg_values)
@@ -703,7 +697,7 @@ class OidcProvider:
             claims = jwt.decode(
                 id_token,
                 key=jwk_set,
-                claims_cls=claims_cls,
+                claims_cls=CodeIDToken,
                 claims_options=claim_options,
                 claims_params=claims_params,
             )
@@ -713,7 +707,7 @@ class OidcProvider:
             claims = jwt.decode(
                 id_token,
                 key=jwk_set,
-                claims_cls=claims_cls,
+                claims_cls=CodeIDToken,
                 claims_options=claim_options,
                 claims_params=claims_params,
             )
@@ -721,7 +715,8 @@ class OidcProvider:
         logger.debug("Decoded id_token JWT %r; validating", claims)
 
         claims.validate(leeway=120)  # allows 2 min of clock skew
-        return UserInfo(claims)
+
+        return claims
 
     async def handle_redirect_request(
         self,
@@ -837,8 +832,22 @@ class OidcProvider:
 
         logger.debug("Successfully obtained OAuth2 token data: %r", token)
 
-        # Now that we have a token, get the userinfo, either by decoding the
-        # `id_token` or by fetching the `userinfo_endpoint`.
+        # If there is an id_token, it should be validated regardless of whether
+        # the userinfo endpoint is used or not.
+        if token.get("id_token") is not None:
+            try:
+                id_token = await self._parse_id_token(token, nonce=session_data.nonce)
+                sid = id_token.get("sid")
+            except Exception as e:
+                logger.exception("Invalid id_token")
+                self._sso_handler.render_error(request, "invalid_token", str(e))
+                return
+        else:
+            id_token = None
+            sid = None
+
+        # Now that we have a token, get the userinfo either from the `id_token`
+        # claims or by fetching the `userinfo_endpoint`.
         if self._uses_userinfo:
             try:
                 userinfo = await self._fetch_userinfo(token)
@@ -846,13 +855,14 @@ class OidcProvider:
                 logger.exception("Could not fetch userinfo")
                 self._sso_handler.render_error(request, "fetch_error", str(e))
                 return
+        elif id_token is not None:
+            userinfo = UserInfo(id_token)
         else:
-            try:
-                userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
-            except Exception as e:
-                logger.exception("Invalid id_token")
-                self._sso_handler.render_error(request, "invalid_token", str(e))
-                return
+            logger.error("Missing id_token in token response")
+            self._sso_handler.render_error(
+                request, "invalid_token", "Missing id_token in token response"
+            )
+            return
 
         # first check if we're doing a UIA
         if session_data.ui_auth_session_id:
@@ -884,7 +894,7 @@ class OidcProvider:
         # Call the mapper to register/login the user
         try:
             await self._complete_oidc_login(
-                userinfo, token, request, session_data.client_redirect_url
+                userinfo, token, request, session_data.client_redirect_url, sid
             )
         except MappingException as e:
             logger.exception("Could not map user")
@@ -896,6 +906,7 @@ class OidcProvider:
         token: Token,
         request: SynapseRequest,
         client_redirect_url: str,
+        sid: Optional[str],
     ) -> None:
         """Given a UserInfo response, complete the login flow
 
@@ -1008,6 +1019,7 @@ class OidcProvider:
             oidc_response_to_user_attributes,
             grandfather_existing_users,
             extra_attributes,
+            auth_provider_session_id=sid,
         )
 
     def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
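
After this change an `id_token`, whenever the provider returns one, is validated up front and its `sid` claim is carried through to registration; the userinfo itself then comes either from the userinfo endpoint or from those claims. A condensed sketch of that branching, with stand-in names rather than the real handler API:

    from typing import Any, Mapping, Optional, Tuple

    def pick_userinfo_source(
        validated_claims: Optional[Mapping[str, Any]],  # claims of an already-validated id_token
        uses_userinfo_endpoint: bool,
    ) -> Tuple[str, Optional[str]]:
        """Return (where the userinfo comes from, the provider session id if any)."""
        sid = validated_claims.get("sid") if validated_claims else None
        if uses_userinfo_endpoint:
            return "userinfo_endpoint", sid
        if validated_claims is not None:
            return "id_token_claims", sid
        # No userinfo endpoint and no id_token: the provider misbehaved.
        raise ValueError("Missing id_token in token response")

    print(pick_userinfo_source({"sid": "abc", "sub": "user"}, uses_userinfo_endpoint=False))
    # ('id_token_claims', 'abc')
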
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index a0e6a01775..f08a516a75 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -1,4 +1,5 @@
 # Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -116,7 +117,13 @@ class RegistrationHandler:
             self.pusher_pool = hs.get_pusherpool()
 
         self.session_lifetime = hs.config.registration.session_lifetime
-        self.access_token_lifetime = hs.config.registration.access_token_lifetime
+        self.nonrefreshable_access_token_lifetime = (
+            hs.config.registration.nonrefreshable_access_token_lifetime
+        )
+        self.refreshable_access_token_lifetime = (
+            hs.config.registration.refreshable_access_token_lifetime
+        )
+        self.refresh_token_lifetime = hs.config.registration.refresh_token_lifetime
 
         init_counters_for_auth_provider("")
 
@@ -739,6 +746,7 @@ class RegistrationHandler:
         is_appservice_ghost: bool = False,
         auth_provider_id: Optional[str] = None,
         should_issue_refresh_token: bool = False,
+        auth_provider_session_id: Optional[str] = None,
     ) -> Tuple[str, str, Optional[int], Optional[str]]:
         """Register a device for a user and generate an access token.
 
@@ -749,9 +757,9 @@ class RegistrationHandler:
             device_id: The device ID to check, or None to generate a new one.
             initial_display_name: An optional display name for the device.
             is_guest: Whether this is a guest account
-            auth_provider_id: The SSO IdP the user used, if any (just used for the
-                prometheus metrics).
+            auth_provider_id: The SSO IdP the user used, if any.
             should_issue_refresh_token: Whether it should also issue a refresh token
+            auth_provider_session_id: The session ID received during login from the SSO IdP.
         Returns:
             Tuple of device ID, access token, access token expiration time and refresh token
         """
@@ -762,6 +770,8 @@ class RegistrationHandler:
             is_guest=is_guest,
             is_appservice_ghost=is_appservice_ghost,
             should_issue_refresh_token=should_issue_refresh_token,
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
         )
 
         login_counter.labels(
@@ -784,6 +794,8 @@ class RegistrationHandler:
         is_guest: bool = False,
         is_appservice_ghost: bool = False,
         should_issue_refresh_token: bool = False,
+        auth_provider_id: Optional[str] = None,
+        auth_provider_session_id: Optional[str] = None,
     ) -> LoginDict:
         """Helper for register_device
 
@@ -791,38 +803,86 @@ class RegistrationHandler:
         class and RegisterDeviceReplicationServlet.
         """
         assert not self.hs.config.worker.worker_app
-        valid_until_ms = None
+        now_ms = self.clock.time_msec()
+        access_token_expiry = None
         if self.session_lifetime is not None:
             if is_guest:
                 raise Exception(
                     "session_lifetime is not currently implemented for guest access"
                 )
-            valid_until_ms = self.clock.time_msec() + self.session_lifetime
+            access_token_expiry = now_ms + self.session_lifetime
+
+        if self.nonrefreshable_access_token_lifetime is not None:
+            if access_token_expiry is not None:
+                # Don't allow the non-refreshable access token to outlive the
+                # session.
+                access_token_expiry = min(
+                    now_ms + self.nonrefreshable_access_token_lifetime,
+                    access_token_expiry,
+                )
+            else:
+                access_token_expiry = now_ms + self.nonrefreshable_access_token_lifetime
 
         refresh_token = None
         refresh_token_id = None
 
         registered_device_id = await self.device_handler.check_device_registered(
-            user_id, device_id, initial_display_name
+            user_id,
+            device_id,
+            initial_display_name,
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
         )
         if is_guest:
-            assert valid_until_ms is None
+            assert access_token_expiry is None
             access_token = self.macaroon_gen.generate_guest_access_token(user_id)
         else:
             if should_issue_refresh_token:
+                # A refreshable access token lifetime must be configured
+                # since we're told to issue a refresh token (the caller checks
+                # that this value is set before setting this flag).
+                assert self.refreshable_access_token_lifetime is not None
+
+                # Set the expiry time of the refreshable access token
+                access_token_expiry = now_ms + self.refreshable_access_token_lifetime
+
+                # Set the refresh token expiry time (if configured)
+                refresh_token_expiry = None
+                if self.refresh_token_lifetime is not None:
+                    refresh_token_expiry = now_ms + self.refresh_token_lifetime
+
+                # Set an ultimate session expiry time (if configured)
+                ultimate_session_expiry_ts = None
+                if self.session_lifetime is not None:
+                    ultimate_session_expiry_ts = now_ms + self.session_lifetime
+
+                    # Also ensure that the issued tokens don't outlive the
+                    # session.
+                    # (It would be weird to configure a homeserver with a shorter
+                    # session lifetime than token lifetime, but may as well handle
+                    # it.)
+                    access_token_expiry = min(
+                        access_token_expiry, ultimate_session_expiry_ts
+                    )
+                    if refresh_token_expiry is not None:
+                        refresh_token_expiry = min(
+                            refresh_token_expiry, ultimate_session_expiry_ts
+                        )
+
                 (
                     refresh_token,
                     refresh_token_id,
-                ) = await self._auth_handler.get_refresh_token_for_user_id(
+                ) = await self._auth_handler.create_refresh_token_for_user_id(
                     user_id,
                     device_id=registered_device_id,
+                    expiry_ts=refresh_token_expiry,
+                    ultimate_session_expiry_ts=ultimate_session_expiry_ts,
                 )
-                valid_until_ms = self.clock.time_msec() + self.access_token_lifetime
 
-            access_token = await self._auth_handler.get_access_token_for_user_id(
+            access_token = await self._auth_handler.create_access_token_for_user_id(
                 user_id,
                 device_id=registered_device_id,
-                valid_until_ms=valid_until_ms,
+                valid_until_ms=access_token_expiry,
                 is_appservice_ghost=is_appservice_ghost,
                 refresh_token_id=refresh_token_id,
             )
@@ -830,7 +890,7 @@ class RegistrationHandler:
         return {
             "device_id": registered_device_id,
             "access_token": access_token,
-            "valid_until_ms": valid_until_ms,
+            "valid_until_ms": access_token_expiry,
             "refresh_token": refresh_token,
         }
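
The expiry arithmetic above boils down to clamping every issued token to the overall session lifetime. A small standalone sketch of that clamping, with all lifetimes in milliseconds and entirely hypothetical values:

    from typing import Optional, Tuple

    def compute_expiries(
        now_ms: int,
        refreshable_access_token_lifetime: int,
        refresh_token_lifetime: Optional[int],
        session_lifetime: Optional[int],
    ) -> Tuple[int, Optional[int], Optional[int]]:
        access_token_expiry = now_ms + refreshable_access_token_lifetime
        refresh_token_expiry = (
            now_ms + refresh_token_lifetime if refresh_token_lifetime else None
        )
        ultimate_session_expiry = (
            now_ms + session_lifetime if session_lifetime else None
        )
        if ultimate_session_expiry is not None:
            # Nothing may outlive the session itself.
            access_token_expiry = min(access_token_expiry, ultimate_session_expiry)
            if refresh_token_expiry is not None:
                refresh_token_expiry = min(refresh_token_expiry, ultimate_session_expiry)
        return access_token_expiry, refresh_token_expiry, ultimate_session_expiry

    # e.g. 5 minute access tokens, 1 day refresh tokens, 12 hour sessions:
    print(compute_expiries(0, 5 * 60_000, 24 * 3_600_000, 12 * 3_600_000))
    # (300000, 43200000, 43200000)
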
 
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index f9a099c4f3..2bcdf32dcc 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -46,6 +46,7 @@ from synapse.api.constants import (
 from synapse.api.errors import (
     AuthError,
     Codes,
+    HttpResponseException,
     LimitExceededError,
     NotFoundError,
     StoreError,
@@ -56,6 +57,8 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.event_auth import validate_event_for_room_version
 from synapse.events import EventBase
 from synapse.events.utils import copy_power_levels_contents
+from synapse.federation.federation_client import InvalidResponseError
+from synapse.handlers.federation import get_domains_from_state
 from synapse.rest.admin._base import assert_user_is_admin
 from synapse.storage.state import StateFilter
 from synapse.streams import EventSource
@@ -775,8 +778,11 @@ class RoomCreationHandler:
             raise SynapseError(403, "Room visibility value not allowed.")
 
         if is_public:
+            room_aliases = []
+            if room_alias:
+                room_aliases.append(room_alias.to_string())
             if not self.config.roomdirectory.is_publishing_room_allowed(
-                user_id, room_id, room_alias
+                user_id, room_id, room_aliases
             ):
                 # Let's just return a generic message, as there may be all sorts of
                 # reasons why we said no. TODO: Allow configurable error messages
@@ -1217,6 +1223,147 @@ class RoomContextHandler:
         return results
 
 
+class TimestampLookupHandler:
+    def __init__(self, hs: "HomeServer"):
+        self.server_name = hs.hostname
+        self.store = hs.get_datastore()
+        self.state_handler = hs.get_state_handler()
+        self.federation_client = hs.get_federation_client()
+
+    async def get_event_for_timestamp(
+        self,
+        requester: Requester,
+        room_id: str,
+        timestamp: int,
+        direction: str,
+    ) -> Tuple[str, int]:
+        """Find the closest event to the given timestamp in the given direction.
+        If we can't find an event locally or the event we have locally is next to a gap,
+        it will ask other federated homeservers for an event.
+
+        Args:
+            requester: The user making the request according to the access token
+            room_id: Room to fetch the event from
+            timestamp: The point in time (inclusive) we should navigate from in
+                the given direction to find the closest event.
+            direction: ["f"|"b"] to indicate whether we should navigate forward
+                or backward from the given timestamp to find the closest event.
+
+        Returns:
+            A tuple containing the `event_id` closest to the given timestamp in
+            the given direction and the `origin_server_ts`.
+
+        Raises:
+            SynapseError if unable to find any event locally in the given direction
+        """
+
+        local_event_id = await self.store.get_event_id_for_timestamp(
+            room_id, timestamp, direction
+        )
+        logger.debug(
+            "get_event_for_timestamp: locally, we found event_id=%s closest to timestamp=%s",
+            local_event_id,
+            timestamp,
+        )
+
+        # Check for gaps in the history where events could be hiding in between
+        # the timestamp given and the event we were able to find locally
+        is_event_next_to_backward_gap = False
+        is_event_next_to_forward_gap = False
+        if local_event_id:
+            local_event = await self.store.get_event(
+                local_event_id, allow_none=False, allow_rejected=False
+            )
+
+            if direction == "f":
+                # We only need to check for a backward gap if we're looking forwards
+                # to ensure there is nothing in between.
+                is_event_next_to_backward_gap = (
+                    await self.store.is_event_next_to_backward_gap(local_event)
+                )
+            elif direction == "b":
+                # We only need to check for a forward gap if we're looking backwards
+                # to ensure there is nothing in between
+                is_event_next_to_forward_gap = (
+                    await self.store.is_event_next_to_forward_gap(local_event)
+                )
+
+        # If we found a gap, we should probably ask another homeserver first
+        # about more history in between
+        if (
+            not local_event_id
+            or is_event_next_to_backward_gap
+            or is_event_next_to_forward_gap
+        ):
+            logger.debug(
+                "get_event_for_timestamp: locally, we found event_id=%s closest to timestamp=%s which is next to a gap in event history so we're asking other homeservers first",
+                local_event_id,
+                timestamp,
+            )
+
+            # Find other homeservers from the given state in the room
+            curr_state = await self.state_handler.get_current_state(room_id)
+            curr_domains = get_domains_from_state(curr_state)
+            likely_domains = [
+                domain for domain, depth in curr_domains if domain != self.server_name
+            ]
+
+            # Loop through each homeserver candidate until we get a successful response
+            for domain in likely_domains:
+                try:
+                    remote_response = await self.federation_client.timestamp_to_event(
+                        domain, room_id, timestamp, direction
+                    )
+                    logger.debug(
+                        "get_event_for_timestamp: response from domain(%s)=%s",
+                        domain,
+                        remote_response,
+                    )
+
+                    # TODO: Do we want to persist this as an extremity?
+                    # TODO: I think ideally, we would try to backfill from
+                    # this event and run this whole
+                    # `get_event_for_timestamp` function again to make sure
+                    # they didn't give us an event from their gappy history.
+                    remote_event_id = remote_response.event_id
+                    origin_server_ts = remote_response.origin_server_ts
+
+                    # Only return the remote event if it's closer than the local event
+                    if not local_event or (
+                        abs(origin_server_ts - timestamp)
+                        < abs(local_event.origin_server_ts - timestamp)
+                    ):
+                        return remote_event_id, origin_server_ts
+                except (HttpResponseException, InvalidResponseError) as ex:
+                    # Let's not put a high priority on some other homeserver
+                    # failing to respond or giving a random response
+                    logger.debug(
+                        "Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s",
+                        domain,
+                        type(ex).__name__,
+                        ex,
+                        ex.args,
+                    )
+                except Exception as ex:
+                    # But we do want to see some exceptions in our code
+                    logger.warning(
+                        "Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s",
+                        domain,
+                        type(ex).__name__,
+                        ex,
+                        ex.args,
+                    )
+
+        if not local_event_id:
+            raise SynapseError(
+                404,
+                "Unable to find event from %s in direction %s" % (timestamp, direction),
+                errcode=Codes.NOT_FOUND,
+            )
+
+        return local_event_id, local_event.origin_server_ts
+
+
 class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
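
`get_event_for_timestamp` only prefers a remote answer when it lands strictly closer to the requested timestamp than the best local candidate. A tiny sketch of that comparison (event IDs and timestamps are invented):

    from typing import Optional, Tuple

    def closer_event(
        timestamp: int,
        local: Optional[Tuple[str, int]],   # (event_id, origin_server_ts)
        remote: Optional[Tuple[str, int]],
    ) -> Tuple[str, int]:
        if remote is not None and (
            local is None
            or abs(remote[1] - timestamp) < abs(local[1] - timestamp)
        ):
            return remote
        if local is None:
            raise LookupError("Unable to find event")
        return local

    print(closer_event(1_000, ("$local", 900), ("$remote", 990)))  # ('$remote', 990)
    print(closer_event(1_000, ("$local", 995), ("$remote", 900)))  # ('$local', 995)
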
diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
index 0723286383..f880aa93d2 100644
--- a/synapse/handlers/room_batch.py
+++ b/synapse/handlers/room_batch.py
@@ -221,6 +221,7 @@ class RoomBatchHandler:
                     action=membership,
                     content=event_dict["content"],
                     outlier=True,
+                    historical=True,
                     prev_event_ids=[prev_event_id_for_state_chain],
                     # Make sure to use a copy of this list because we modify it
                     # later in the loop here. Otherwise it will be the same
@@ -240,6 +241,7 @@ class RoomBatchHandler:
                     ),
                     event_dict,
                     outlier=True,
+                    historical=True,
                     prev_event_ids=[prev_event_id_for_state_chain],
                     # Make sure to use a copy of this list because we modify it
                     # later in the loop here. Otherwise it will be the same
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 08244b690d..a6dbff637f 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -268,6 +268,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         content: Optional[dict] = None,
         require_consent: bool = True,
         outlier: bool = False,
+        historical: bool = False,
     ) -> Tuple[str, int]:
         """
         Internal membership update function to get an existing event or create
@@ -293,6 +294,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             outlier: Indicates whether the event is an `outlier`, i.e. if
                 it's from an arbitrary point and floating in the DAG as
                 opposed to being inline with the current DAG.
+            historical: Indicates whether the message is being inserted
+                back in time around some existing events. This is used to skip
+                a few checks and mark the event as backfilled.
 
         Returns:
             Tuple of event ID and stream ordering position
@@ -337,6 +341,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             auth_event_ids=auth_event_ids,
             require_consent=require_consent,
             outlier=outlier,
+            historical=historical,
         )
 
         prev_state_ids = await context.get_prev_state_ids()
@@ -433,6 +438,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         new_room: bool = False,
         require_consent: bool = True,
         outlier: bool = False,
+        historical: bool = False,
         prev_event_ids: Optional[List[str]] = None,
         auth_event_ids: Optional[List[str]] = None,
     ) -> Tuple[str, int]:
@@ -454,6 +460,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             outlier: Indicates whether the event is an `outlier`, i.e. if
                 it's from an arbitrary point and floating in the DAG as
                 opposed to being inline with the current DAG.
+            historical: Indicates whether the message is being inserted
+                back in time around some existing events. This is used to skip
+                a few checks and mark the event as backfilled.
             prev_event_ids: The event IDs to use as the prev events
             auth_event_ids:
                 The event ids to use as the auth_events for the new event.
@@ -487,6 +496,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 new_room=new_room,
                 require_consent=require_consent,
                 outlier=outlier,
+                historical=historical,
                 prev_event_ids=prev_event_ids,
                 auth_event_ids=auth_event_ids,
             )
@@ -507,6 +517,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         new_room: bool = False,
         require_consent: bool = True,
         outlier: bool = False,
+        historical: bool = False,
         prev_event_ids: Optional[List[str]] = None,
         auth_event_ids: Optional[List[str]] = None,
     ) -> Tuple[str, int]:
@@ -530,6 +541,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             outlier: Indicates whether the event is an `outlier`, i.e. if
                 it's from an arbitrary point and floating in the DAG as
                 opposed to being inline with the current DAG.
+            historical: Indicates whether the message is being inserted
+                back in time around some existing events. This is used to skip
+                a few checks and mark the event as backfilled.
             prev_event_ids: The event IDs to use as the prev events
             auth_event_ids:
                 The event ids to use as the auth_events for the new event.
@@ -657,6 +671,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 content=content,
                 require_consent=require_consent,
                 outlier=outlier,
+                historical=historical,
             )
 
         latest_event_ids = await self.store.get_prev_events_for_room(room_id)
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index fb26ee7ad7..b2cfe537df 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -36,8 +36,9 @@ from synapse.api.errors import (
     SynapseError,
     UnsupportedRoomVersionError,
 )
+from synapse.api.ratelimiting import Ratelimiter
 from synapse.events import EventBase
-from synapse.types import JsonDict
+from synapse.types import JsonDict, Requester
 from synapse.util.caches.response_cache import ResponseCache
 
 if TYPE_CHECKING:
@@ -93,11 +94,14 @@ class RoomSummaryHandler:
         self._event_serializer = hs.get_event_client_serializer()
         self._server_name = hs.hostname
         self._federation_client = hs.get_federation_client()
+        self._ratelimiter = Ratelimiter(
+            store=self._store, clock=hs.get_clock(), rate_hz=5, burst_count=10
+        )
 
         # If a user tries to fetch the same page multiple times in quick succession,
         # only process the first attempt and return its result to subsequent requests.
         self._pagination_response_cache: ResponseCache[
-            Tuple[str, bool, Optional[int], Optional[int], Optional[str]]
+            Tuple[str, str, bool, Optional[int], Optional[int], Optional[str]]
         ] = ResponseCache(
             hs.get_clock(),
             "get_room_hierarchy",
@@ -249,7 +253,7 @@ class RoomSummaryHandler:
 
     async def get_room_hierarchy(
         self,
-        requester: str,
+        requester: Requester,
         requested_room_id: str,
         suggested_only: bool = False,
         max_depth: Optional[int] = None,
@@ -276,15 +280,24 @@ class RoomSummaryHandler:
         Returns:
             The JSON hierarchy dictionary.
         """
+        await self._ratelimiter.ratelimit(requester)
+
         # If a user tries to fetch the same page multiple times in quick succession,
         # only process the first attempt and return its result to subsequent requests.
         #
         # This is due to the pagination process mutating internal state, attempting
         # to process multiple requests for the same page will result in errors.
         return await self._pagination_response_cache.wrap(
-            (requested_room_id, suggested_only, max_depth, limit, from_token),
+            (
+                requester.user.to_string(),
+                requested_room_id,
+                suggested_only,
+                max_depth,
+                limit,
+                from_token,
+            ),
             self._get_room_hierarchy,
-            requester,
+            requester.user.to_string(),
             requested_room_id,
             suggested_only,
             max_depth,
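
Two separate protections are added here: the requesting user becomes part of the pagination-cache key, so one user's cached hierarchy page can never be replayed to another, and each user is rate limited at `rate_hz=5`, `burst_count=10`. A toy token-bucket sketch of what those two parameters mean (this is not Synapse's `Ratelimiter` implementation):

    import time

    class TokenBucket:
        """Toy limiter: `burst_count` immediate actions, refilled at `rate_hz` per second."""

        def __init__(self, rate_hz: float, burst_count: int):
            self.rate_hz = rate_hz
            self.burst_count = burst_count
            self.tokens = float(burst_count)
            self.last = time.monotonic()

        def allow(self) -> bool:
            now = time.monotonic()
            self.tokens = min(
                self.burst_count, self.tokens + (now - self.last) * self.rate_hz
            )
            self.last = now
            if self.tokens >= 1:
                self.tokens -= 1
                return True
            return False

    limiter = TokenBucket(rate_hz=5, burst_count=10)
    print(sum(limiter.allow() for _ in range(20)))  # roughly 10: the burst allowance
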
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 49fde01cf0..65c27bc64a 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -365,6 +365,7 @@ class SsoHandler:
         sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
         grandfather_existing_users: Callable[[], Awaitable[Optional[str]]],
         extra_login_attributes: Optional[JsonDict] = None,
+        auth_provider_session_id: Optional[str] = None,
     ) -> None:
         """
         Given an SSO ID, retrieve the user ID for it and possibly register the user.
@@ -415,6 +416,8 @@ class SsoHandler:
             extra_login_attributes: An optional dictionary of extra
                 attributes to be provided to the client in the login response.
 
+            auth_provider_session_id: An optional session ID from the IdP.
+
         Raises:
             MappingException if there was a problem mapping the response to a user.
             RedirectException: if the mapping provider needs to redirect the user
@@ -490,6 +493,7 @@ class SsoHandler:
             client_redirect_url,
             extra_login_attributes,
             new_user=new_user,
+            auth_provider_session_id=auth_provider_session_id,
         )
 
     async def _call_attribute_mapper(
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 891435c14d..53d4627147 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -334,6 +334,19 @@ class SyncHandler:
         full_state: bool,
         cache_context: ResponseCacheContext[SyncRequestKey],
     ) -> SyncResult:
+        """The start of the machinery that produces a /sync response.
+
+        See https://spec.matrix.org/v1.1/client-server-api/#syncing for full details.
+
+        This method does high-level bookkeeping:
+        - tracking the kind of sync in the logging context
+        - deleting any to_device messages whose delivery has been acknowledged.
+        - deciding if we should dispatch an instant or delayed response
+        - marking the sync as being lazily loaded, if appropriate
+
+        Computing the body of the response begins in the next method,
+        `current_sync_for_user`.
+        """
         if since_token is None:
             sync_type = "initial_sync"
         elif full_state:
@@ -363,7 +376,7 @@ class SyncHandler:
                 sync_config, since_token, full_state=full_state
             )
         else:
-
+            # Otherwise, we wait for something to happen and report it to the user.
             async def current_sync_callback(
                 before_token: StreamToken, after_token: StreamToken
             ) -> SyncResult:
@@ -402,7 +415,12 @@ class SyncHandler:
         since_token: Optional[StreamToken] = None,
         full_state: bool = False,
     ) -> SyncResult:
-        """Get the sync for client needed to match what the server has now."""
+        """Generates the response body of a sync result, represented as a SyncResult.
+
+        This is a wrapper around `generate_sync_result` which starts an open tracing
+        span to track the sync. See `generate_sync_result` for the next part of your
+        indoctrination.
+        """
         with start_active_span("current_sync_for_user"):
             log_kv({"since_token": since_token})
             sync_result = await self.generate_sync_result(
@@ -560,7 +578,7 @@ class SyncHandler:
                 # that have happened since `since_key` up to `end_key`, so we
                 # can just use `get_room_events_stream_for_room`.
                 # Otherwise, we want to return the last N events in the room
-                # in toplogical ordering.
+                # in topological ordering.
                 if since_key:
                     events, end_key = await self.store.get_room_events_stream_for_room(
                         room_id,
@@ -1042,7 +1060,18 @@ class SyncHandler:
         since_token: Optional[StreamToken] = None,
         full_state: bool = False,
     ) -> SyncResult:
-        """Generates a sync result."""
+        """Generates the response body of a sync result.
+
+        This is represented by a `SyncResult` struct, which is built from small pieces
+        using a `SyncResultBuilder`. See also
+            https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3sync
+        The `sync_result_builder` is passed as a mutable ("inout") parameter to various
+        helper functions. These retrieve and process the data which forms the sync body,
+        often writing to the `sync_result_builder` to store their output.
+
+        At the end, we transfer data from the `sync_result_builder` to a new `SyncResult`
+        instance to signify that the sync calculation is complete.
+        """
         # NB: The now_token gets changed by some of the generate_sync_* methods,
         # this is due to some of the underlying streams not supporting the ability
         # to query up to a given point.
@@ -1344,14 +1373,22 @@ class SyncHandler:
     async def _generate_sync_entry_for_account_data(
         self, sync_result_builder: "SyncResultBuilder"
     ) -> Dict[str, Dict[str, JsonDict]]:
-        """Generates the account data portion of the sync response. Populates
-        `sync_result_builder` with the result.
+        """Generates the account data portion of the sync response.
+
+        Account data (called "Client Config" in the spec) can be set either globally
+        or for a specific room. Account data consists of a list of events which
+        accumulate state, much like a room.
+
+        This function retrieves global and per-room account data. The former is written
+        to the given `sync_result_builder`. The latter is returned directly, to be
+        later written to the `sync_result_builder` on a room-by-room basis.
 
         Args:
             sync_result_builder
 
         Returns:
-            A dictionary containing the per room account data.
+            A dictionary whose keys (room ids) map to the per room account data for that
+            room.
         """
         sync_config = sync_result_builder.sync_config
         user_id = sync_result_builder.sync_config.user.to_string()
@@ -1359,7 +1396,7 @@ class SyncHandler:
 
         if since_token and not sync_result_builder.full_state:
             (
-                account_data,
+                global_account_data,
                 account_data_by_room,
             ) = await self.store.get_updated_account_data_for_user(
                 user_id, since_token.account_data_key
@@ -1370,23 +1407,23 @@ class SyncHandler:
             )
 
             if push_rules_changed:
-                account_data["m.push_rules"] = await self.push_rules_for_user(
+                global_account_data["m.push_rules"] = await self.push_rules_for_user(
                     sync_config.user
                 )
         else:
             (
-                account_data,
+                global_account_data,
                 account_data_by_room,
             ) = await self.store.get_account_data_for_user(sync_config.user.to_string())
 
-            account_data["m.push_rules"] = await self.push_rules_for_user(
+            global_account_data["m.push_rules"] = await self.push_rules_for_user(
                 sync_config.user
             )
 
         account_data_for_user = await sync_config.filter_collection.filter_account_data(
             [
                 {"type": account_data_type, "content": content}
-                for account_data_type, content in account_data.items()
+                for account_data_type, content in global_account_data.items()
             ]
         )
 
@@ -1460,15 +1497,22 @@ class SyncHandler:
         """Generates the rooms portion of the sync response. Populates the
         `sync_result_builder` with the result.
 
+        In the response that reaches the client, rooms are divided into four categories:
+        `invite`, `join`, `knock`, `leave`. These aren't the same as the four sets of
+        room ids returned by this function.
+
         Args:
             sync_result_builder
             account_data_by_room: Dictionary of per room account data
 
         Returns:
-            Returns a 4-tuple of
-            `(newly_joined_rooms, newly_joined_or_invited_users,
-            newly_left_rooms, newly_left_users)`
+            Returns a 4-tuple whose entries are:
+            - newly_joined_rooms
+            - newly_joined_or_invited_or_knocked_users
+            - newly_left_rooms
+            - newly_left_users
         """
+        # Start by fetching all ephemeral events in rooms we've joined (if required).
         user_id = sync_result_builder.sync_config.user.to_string()
         block_all_room_ephemeral = (
             sync_result_builder.since_token is None
@@ -1590,6 +1634,8 @@ class SyncHandler:
     ) -> bool:
         """Returns whether there may be any new events that should be sent down
         the sync. Returns True if there are.
+
+        Does not modify the `sync_result_builder`.
         """
         user_id = sync_result_builder.sync_config.user.to_string()
         since_token = sync_result_builder.since_token
@@ -1597,12 +1643,13 @@ class SyncHandler:
 
         assert since_token
 
-        # Get a list of membership change events that have happened.
-        rooms_changed = await self.store.get_membership_changes_for_user(
+        # Get a list of membership change events that have happened to the user
+        # requesting the sync.
+        membership_changes = await self.store.get_membership_changes_for_user(
             user_id, since_token.room_key, now_token.room_key
         )
 
-        if rooms_changed:
+        if membership_changes:
             return True
 
         stream_id = since_token.room_key.stream
@@ -1614,7 +1661,25 @@ class SyncHandler:
     async def _get_rooms_changed(
         self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
     ) -> _RoomChanges:
-        """Gets the the changes that have happened since the last sync."""
+        """Determine the changes in rooms to report to the user.
+
+        Ideally, we want to report all events whose stream ordering `s` lies in the
+        range `since_token < s <= now_token`, where the two tokens are read from the
+        sync_result_builder.
+
+        If there are too many events in that range to report, things get complicated.
+        In this situation we return a truncated list of the most recent events, and
+        indicate in the response that there is a "gap" of omitted events. Additionally:
+
+        - we include a "state_delta", to describe the changes in state over the gap,
+        - we include all membership events applying to the user making the request,
+          even those in the gap.
+
+        See the spec for the rationale:
+            https://spec.matrix.org/v1.1/client-server-api/#syncing
+
+        The sync_result_builder is not modified by this function.
+        """
         user_id = sync_result_builder.sync_config.user.to_string()
         since_token = sync_result_builder.since_token
         now_token = sync_result_builder.now_token
@@ -1622,21 +1687,36 @@ class SyncHandler:
 
         assert since_token
 
-        # Get a list of membership change events that have happened.
-        rooms_changed = await self.store.get_membership_changes_for_user(
+        # The spec
+        #     https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3sync
+        # notes that membership events need special consideration:
+        #
+        # > When a sync is limited, the server MUST return membership events for events
+        # > in the gap (between since and the start of the returned timeline), regardless
+        # > as to whether or not they are redundant.
+        #
+        # We fetch such events here, but we only seem to use them for categorising rooms
+        # as newly joined, newly left, invited or knocked.
+        # TODO: we've already called this function and ran this query in
+        #       _have_rooms_changed. We could keep the results in memory to avoid a
+        #       second query, at the cost of more complicated source code.
+        membership_change_events = await self.store.get_membership_changes_for_user(
             user_id, since_token.room_key, now_token.room_key
         )
 
         mem_change_events_by_room_id: Dict[str, List[EventBase]] = {}
-        for event in rooms_changed:
+        for event in membership_change_events:
             mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
 
-        newly_joined_rooms = []
-        newly_left_rooms = []
-        room_entries = []
-        invited = []
-        knocked = []
+        newly_joined_rooms: List[str] = []
+        newly_left_rooms: List[str] = []
+        room_entries: List[RoomSyncResultBuilder] = []
+        invited: List[InvitedSyncResult] = []
+        knocked: List[KnockedSyncResult] = []
         for room_id, events in mem_change_events_by_room_id.items():
+            # The body of this loop will add this room to at least one of the five lists
+            # above. Things get messy if you've e.g. joined, left, joined then left the
+            # room all in the same sync period.
             logger.debug(
                 "Membership changes in %s: [%s]",
                 room_id,
@@ -1781,7 +1861,9 @@ class SyncHandler:
 
         timeline_limit = sync_config.filter_collection.timeline_limit()
 
-        # Get all events for rooms we're currently joined to.
+        # Get all events since the `from_key` in rooms we're currently joined to.
+        # If there are too many, we get the most recent events only. This leaves
+        # a "gap" in the timeline, as described by the spec for /sync.
         room_to_events = await self.store.get_room_events_stream_for_rooms(
             room_ids=sync_result_builder.joined_room_ids,
             from_key=since_token.room_key,
@@ -1842,6 +1924,10 @@ class SyncHandler:
     ) -> _RoomChanges:
         """Returns entries for all rooms for the user.
 
+        Like `_get_rooms_changed`, but assumes the `since_token` is `None`.
+
+        This function does not modify the sync_result_builder.
+
         Args:
             sync_result_builder
             ignored_users: Set of users ignored by user.
@@ -1853,16 +1939,9 @@ class SyncHandler:
         now_token = sync_result_builder.now_token
         sync_config = sync_result_builder.sync_config
 
-        membership_list = (
-            Membership.INVITE,
-            Membership.KNOCK,
-            Membership.JOIN,
-            Membership.LEAVE,
-            Membership.BAN,
-        )
-
         room_list = await self.store.get_rooms_for_local_user_where_membership_is(
-            user_id=user_id, membership_list=membership_list
+            user_id=user_id,
+            membership_list=Membership.LIST,
         )
 
         room_entries = []
@@ -2212,8 +2291,7 @@ def _calculate_state(
     # to only include membership events for the senders in the timeline.
     # In practice, we can do this by removing them from the p_ids list,
     # which is the list of relevant state we know we have already sent to the client.
-    # see https://github.com/matrix-org/synapse/pull/2970
-    #            /files/efcdacad7d1b7f52f879179701c7e0d9b763511f#r204732809
+    # see https://github.com/matrix-org/synapse/pull/2970/files/efcdacad7d1b7f52f879179701c7e0d9b763511f#r204732809
 
     if lazy_load_members:
         p_ids.difference_update(
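
The renamed `membership_change_events` feed the per-room categorisation described above ("joined, left, joined then left … in the same sync period"). The `_get_all_rooms` hunk also swaps the hand-written membership tuple for `Membership.LIST`, which, judging by the removed code, is the canonical tuple of all membership values. A deliberately simplified toy of the bucketing, keyed only on the user's final membership in each room (the real loop also has to handle the intermediate transitions):

    from typing import Dict, List

    def bucket_rooms(changes: Dict[str, List[str]]) -> Dict[str, List[str]]:
        """changes: room_id -> ordered list of the user's membership values."""
        buckets: Dict[str, List[str]] = {
            "joined": [], "left": [], "invited": [], "knocked": []
        }
        final_to_bucket = {
            "join": "joined", "leave": "left", "invite": "invited", "knock": "knocked"
        }
        for room_id, memberships in changes.items():
            bucket = final_to_bucket.get(memberships[-1])
            if bucket is not None:
                buckets[bucket].append(room_id)
        return buckets

    print(bucket_rooms({"!a": ["invite", "join"], "!b": ["join", "leave"]}))
    # {'joined': ['!a'], 'left': ['!b'], 'invited': [], 'knocked': []}
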
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 22c6174821..1676ebd057 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -90,7 +90,7 @@ class FollowerTypingHandler:
         self.wheel_timer = WheelTimer(bucket_size=5000)
 
     @wrap_as_background_process("typing._handle_timeouts")
-    def _handle_timeouts(self) -> None:
+    async def _handle_timeouts(self) -> None:
         logger.debug("Checking for typing timeouts")
 
         now = self.clock.time_msec()
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 91ba93372c..6dd9b9ad03 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -79,6 +79,35 @@ def parse_integer(
     return parse_integer_from_args(args, name, default, required)
 
 
+@overload
+def parse_integer_from_args(
+    args: Mapping[bytes, Sequence[bytes]],
+    name: str,
+    default: Optional[int] = None,
+) -> Optional[int]:
+    ...
+
+
+@overload
+def parse_integer_from_args(
+    args: Mapping[bytes, Sequence[bytes]],
+    name: str,
+    *,
+    required: Literal[True],
+) -> int:
+    ...
+
+
+@overload
+def parse_integer_from_args(
+    args: Mapping[bytes, Sequence[bytes]],
+    name: str,
+    default: Optional[int] = None,
+    required: bool = False,
+) -> Optional[int]:
+    ...
+
+
 def parse_integer_from_args(
     args: Mapping[bytes, Sequence[bytes]],
     name: str,
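
The new `@overload` stubs let a type checker narrow `parse_integer_from_args`: when `required=True` is passed as a keyword, the result is known to be `int`; otherwise it stays `Optional[int]`. The general pattern, shown on a hypothetical helper rather than Synapse's own:

    from typing import Literal, Optional, overload

    @overload
    def get_int(raw: Optional[str], *, required: Literal[True]) -> int: ...
    @overload
    def get_int(raw: Optional[str], required: bool = False) -> Optional[int]: ...

    def get_int(raw: Optional[str], required: bool = False) -> Optional[int]:
        if raw is None:
            if required:
                raise ValueError("missing required integer parameter")
            return None
        return int(raw)

    n: int = get_int("5", required=True)   # checker sees int, no Optional
    m: Optional[int] = get_int(None)       # checker sees Optional[int]
    print(n, m)
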
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 91ee5c8193..ceef57ad88 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -20,10 +20,25 @@ import os
 import platform
 import threading
 import time
-from typing import Callable, Dict, Iterable, Mapping, Optional, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generic,
+    Iterable,
+    Mapping,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+)
 
 import attr
-from prometheus_client import Counter, Gauge, Histogram
+from prometheus_client import CollectorRegistry, Counter, Gauge, Histogram, Metric
 from prometheus_client.core import (
     REGISTRY,
     CounterMetricFamily,
@@ -32,6 +47,7 @@ from prometheus_client.core import (
 )
 
 from twisted.internet import reactor
+from twisted.internet.base import ReactorBase
 from twisted.python.threadpool import ThreadPool
 
 import synapse
@@ -54,7 +70,7 @@ HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
 
 class RegistryProxy:
     @staticmethod
-    def collect():
+    def collect() -> Iterable[Metric]:
         for metric in REGISTRY.collect():
             if not metric.name.startswith("__"):
                 yield metric
@@ -74,7 +90,7 @@ class LaterGauge:
         ]
     )
 
-    def collect(self):
+    def collect(self) -> Iterable[Metric]:
 
         g = GaugeMetricFamily(self.name, self.desc, labels=self.labels)
 
@@ -93,10 +109,10 @@ class LaterGauge:
 
         yield g
 
-    def __attrs_post_init__(self):
+    def __attrs_post_init__(self) -> None:
         self._register()
 
-    def _register(self):
+    def _register(self) -> None:
         if self.name in all_gauges.keys():
             logger.warning("%s already registered, reregistering" % (self.name,))
             REGISTRY.unregister(all_gauges.pop(self.name))
@@ -105,7 +121,12 @@ class LaterGauge:
         all_gauges[self.name] = self
 
 
-class InFlightGauge:
+# `MetricsEntry` only makes sense when it is a `Protocol`,
+# but `Protocol` can't be used as a `TypeVar` bound.
+MetricsEntry = TypeVar("MetricsEntry")
+
+
+class InFlightGauge(Generic[MetricsEntry]):
     """Tracks number of things (e.g. requests, Measure blocks, etc) in flight
     at any given time.
 
@@ -115,14 +136,19 @@ class InFlightGauge:
     callbacks.
 
     Args:
-        name (str)
-        desc (str)
-        labels (list[str])
-        sub_metrics (list[str]): A list of sub metrics that the callbacks
-            will update.
+        name
+        desc
+        labels
+        sub_metrics: A list of sub metrics that the callbacks will update.
     """
 
-    def __init__(self, name, desc, labels, sub_metrics):
+    def __init__(
+        self,
+        name: str,
+        desc: str,
+        labels: Sequence[str],
+        sub_metrics: Sequence[str],
+    ):
         self.name = name
         self.desc = desc
         self.labels = labels
@@ -130,19 +156,25 @@ class InFlightGauge:
 
         # Create a class which have the sub_metrics values as attributes, which
         # default to 0 on initialization. Used to pass to registered callbacks.
-        self._metrics_class = attr.make_class(
+        self._metrics_class: Type[MetricsEntry] = attr.make_class(
             "_MetricsEntry", attrs={x: attr.ib(0) for x in sub_metrics}, slots=True
         )
 
         # Counts number of in flight blocks for a given set of label values
-        self._registrations: Dict = {}
+        self._registrations: Dict[
+            Tuple[str, ...], Set[Callable[[MetricsEntry], None]]
+        ] = {}
 
         # Protects access to _registrations
         self._lock = threading.Lock()
 
         self._register_with_collector()
 
-    def register(self, key, callback):
+    def register(
+        self,
+        key: Tuple[str, ...],
+        callback: Callable[[MetricsEntry], None],
+    ) -> None:
         """Registers that we've entered a new block with labels `key`.
 
         `callback` gets called each time the metrics are collected. The same
@@ -158,13 +190,17 @@ class InFlightGauge:
         with self._lock:
             self._registrations.setdefault(key, set()).add(callback)
 
-    def unregister(self, key, callback):
+    def unregister(
+        self,
+        key: Tuple[str, ...],
+        callback: Callable[[MetricsEntry], None],
+    ) -> None:
         """Registers that we've exited a block with labels `key`."""
 
         with self._lock:
             self._registrations.setdefault(key, set()).discard(callback)
 
-    def collect(self):
+    def collect(self) -> Iterable[Metric]:
         """Called by prometheus client when it reads metrics.
 
         Note: may be called by a separate thread.
@@ -200,7 +236,7 @@ class InFlightGauge:
                 gauge.add_metric(key, getattr(metrics, name))
             yield gauge
 
-    def _register_with_collector(self):
+    def _register_with_collector(self) -> None:
         if self.name in all_gauges.keys():
             logger.warning("%s already registered, reregistering" % (self.name,))
             REGISTRY.unregister(all_gauges.pop(self.name))
@@ -230,7 +266,7 @@ class GaugeBucketCollector:
         name: str,
         documentation: str,
         buckets: Iterable[float],
-        registry=REGISTRY,
+        registry: CollectorRegistry = REGISTRY,
     ):
         """
         Args:
@@ -257,12 +293,12 @@ class GaugeBucketCollector:
 
         registry.register(self)
 
-    def collect(self):
+    def collect(self) -> Iterable[Metric]:
         # Don't report metrics unless we've already collected some data
         if self._metric is not None:
             yield self._metric
 
-    def update_data(self, values: Iterable[float]):
+    def update_data(self, values: Iterable[float]) -> None:
         """Update the data to be reported by the metric
 
         The existing data is cleared, and each measurement in the input is assigned
@@ -304,7 +340,7 @@ class GaugeBucketCollector:
 
 
 class CPUMetrics:
-    def __init__(self):
+    def __init__(self) -> None:
         ticks_per_sec = 100
         try:
             # Try and get the system config
@@ -314,7 +350,7 @@ class CPUMetrics:
 
         self.ticks_per_sec = ticks_per_sec
 
-    def collect(self):
+    def collect(self) -> Iterable[Metric]:
         if not HAVE_PROC_SELF_STAT:
             return
 
@@ -364,7 +400,7 @@ gc_time = Histogram(
 
 
 class GCCounts:
-    def collect(self):
+    def collect(self) -> Iterable[Metric]:
         cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"])
         for n, m in enumerate(gc.get_count()):
             cm.add_metric([str(n)], m)
@@ -382,7 +418,7 @@ if not running_on_pypy:
 
 
 class PyPyGCStats:
-    def collect(self):
+    def collect(self) -> Iterable[Metric]:
 
         # @stats is a pretty-printer object with __str__() returning a nice table,
         # plus some fields that contain data from that table.
@@ -565,7 +601,7 @@ def register_threadpool(name: str, threadpool: ThreadPool) -> None:
 
 
 class ReactorLastSeenMetric:
-    def collect(self):
+    def collect(self) -> Iterable[Metric]:
         cm = GaugeMetricFamily(
             "python_twisted_reactor_last_seen",
             "Seconds since the Twisted reactor was last seen",
@@ -584,9 +620,12 @@ MIN_TIME_BETWEEN_GCS = (1.0, 10.0, 30.0)
 _last_gc = [0.0, 0.0, 0.0]
 
 
-def runUntilCurrentTimer(reactor, func):
+F = TypeVar("F", bound=Callable[..., Any])
+
+
+def runUntilCurrentTimer(reactor: ReactorBase, func: F) -> F:
     @functools.wraps(func)
-    def f(*args, **kwargs):
+    def f(*args: Any, **kwargs: Any) -> Any:
         now = reactor.seconds()
         num_pending = 0
 
@@ -649,7 +688,7 @@ def runUntilCurrentTimer(reactor, func):
 
         return ret
 
-    return f
+    return cast(F, f)
 
 
 try:
@@ -677,5 +716,5 @@ __all__ = [
     "start_http_server",
     "LaterGauge",
     "InFlightGauge",
-    "BucketCollector",
+    "GaugeBucketCollector",
 ]
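
A minimal sketch of using the now-generic InFlightGauge (the metric name, label values and callback are hypothetical); the type parameter stands in for the attrs-generated `_MetricsEntry` class whose attributes are the `sub_metrics`.

    from synapse.metrics import InFlightGauge

    class _Entry:
        # Stand-in for the attrs-generated metrics entry; one attribute per sub-metric.
        real_time: float

    gauge: InFlightGauge[_Entry] = InFlightGauge(
        "synapse_example_in_flight",  # hypothetical metric name
        "Example in-flight blocks",
        labels=["method"],
        sub_metrics=["real_time"],
    )

    def _report(entry: _Entry) -> None:
        # Called at scrape time; add this block's contribution to the sub-metrics.
        entry.real_time += 1.0

    # Entering and then leaving a block with label values ("GET",):
    gauge.register(("GET",), _report)
    gauge.unregister(("GET",), _report)
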
diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py
index bb9bcb5592..353d0a63b6 100644
--- a/synapse/metrics/_exposition.py
+++ b/synapse/metrics/_exposition.py
@@ -25,27 +25,25 @@ import math
 import threading
 from http.server import BaseHTTPRequestHandler, HTTPServer
 from socketserver import ThreadingMixIn
-from typing import Dict, List
+from typing import Any, Dict, List, Type, Union
 from urllib.parse import parse_qs, urlparse
 
-from prometheus_client import REGISTRY
+from prometheus_client import REGISTRY, CollectorRegistry
+from prometheus_client.core import Sample
 
 from twisted.web.resource import Resource
+from twisted.web.server import Request
 
 from synapse.util import caches
 
 CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8"
 
 
-INF = float("inf")
-MINUS_INF = float("-inf")
-
-
-def floatToGoString(d):
+def floatToGoString(d: Union[int, float]) -> str:
     d = float(d)
-    if d == INF:
+    if d == math.inf:
         return "+Inf"
-    elif d == MINUS_INF:
+    elif d == -math.inf:
         return "-Inf"
     elif math.isnan(d):
         return "NaN"
@@ -60,7 +58,7 @@ def floatToGoString(d):
         return s
 
 
-def sample_line(line, name):
+def sample_line(line: Sample, name: str) -> str:
     if line.labels:
         labelstr = "{{{0}}}".format(
             ",".join(
@@ -82,7 +80,7 @@ def sample_line(line, name):
     return "{}{} {}{}\n".format(name, labelstr, floatToGoString(line.value), timestamp)
 
 
-def generate_latest(registry, emit_help=False):
+def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> bytes:
 
     # Trigger the cache metrics to be rescraped, which updates the common
     # metrics but do not produce metrics themselves
@@ -187,7 +185,7 @@ class MetricsHandler(BaseHTTPRequestHandler):
 
     registry = REGISTRY
 
-    def do_GET(self):
+    def do_GET(self) -> None:
         registry = self.registry
         params = parse_qs(urlparse(self.path).query)
 
@@ -207,11 +205,11 @@ class MetricsHandler(BaseHTTPRequestHandler):
         self.end_headers()
         self.wfile.write(output)
 
-    def log_message(self, format, *args):
+    def log_message(self, format: str, *args: Any) -> None:
         """Log nothing."""
 
     @classmethod
-    def factory(cls, registry):
+    def factory(cls, registry: CollectorRegistry) -> Type:
         """Returns a dynamic MetricsHandler class tied
         to the passed registry.
         """
@@ -236,7 +234,9 @@ class _ThreadingSimpleServer(ThreadingMixIn, HTTPServer):
     daemon_threads = True
 
 
-def start_http_server(port, addr="", registry=REGISTRY):
+def start_http_server(
+    port: int, addr: str = "", registry: CollectorRegistry = REGISTRY
+) -> None:
     """Starts an HTTP server for prometheus metrics as a daemon thread"""
     CustomMetricsHandler = MetricsHandler.factory(registry)
     httpd = _ThreadingSimpleServer((addr, port), CustomMetricsHandler)
@@ -252,10 +252,10 @@ class MetricsResource(Resource):
 
     isLeaf = True
 
-    def __init__(self, registry=REGISTRY):
+    def __init__(self, registry: CollectorRegistry = REGISTRY):
         self.registry = registry
 
-    def render_GET(self, request):
+    def render_GET(self, request: Request) -> bytes:
         request.setHeader(b"Content-Type", CONTENT_TYPE_LATEST.encode("ascii"))
         response = generate_latest(self.registry)
         request.setHeader(b"Content-Length", str(len(response)))
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 2ab599a334..53c508af91 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -15,19 +15,37 @@
 import logging
 import threading
 from functools import wraps
-from typing import TYPE_CHECKING, Dict, Optional, Set, Union
+from types import TracebackType
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    Optional,
+    Set,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+)
 
+from prometheus_client import Metric
 from prometheus_client.core import REGISTRY, Counter, Gauge
 
 from twisted.internet import defer
 
-from synapse.logging.context import LoggingContext, PreserveLoggingContext
+from synapse.logging.context import (
+    ContextResourceUsage,
+    LoggingContext,
+    PreserveLoggingContext,
+)
 from synapse.logging.opentracing import (
     SynapseTags,
     noop_context_manager,
     start_active_span,
 )
-from synapse.util.async_helpers import maybe_awaitable
 
 if TYPE_CHECKING:
     import resource
@@ -116,7 +134,7 @@ class _Collector:
     before they are returned.
     """
 
-    def collect(self):
+    def collect(self) -> Iterable[Metric]:
         global _background_processes_active_since_last_scrape
 
         # We swap out the _background_processes set with an empty one so that
@@ -144,12 +162,12 @@ REGISTRY.register(_Collector())
 
 
 class _BackgroundProcess:
-    def __init__(self, desc, ctx):
+    def __init__(self, desc: str, ctx: LoggingContext):
         self.desc = desc
         self._context = ctx
-        self._reported_stats = None
+        self._reported_stats: Optional[ContextResourceUsage] = None
 
-    def update_metrics(self):
+    def update_metrics(self) -> None:
         """Updates the metrics with values from this process."""
         new_stats = self._context.get_resource_usage()
         if self._reported_stats is None:
@@ -169,7 +187,16 @@ class _BackgroundProcess:
         )
 
 
-def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwargs):
+R = TypeVar("R")
+
+
+def run_as_background_process(
+    desc: str,
+    func: Callable[..., Awaitable[Optional[R]]],
+    *args: Any,
+    bg_start_span: bool = True,
+    **kwargs: Any,
+) -> "defer.Deferred[Optional[R]]":
     """Run the given function in its own logcontext, with resource metrics
 
     This should be used to wrap processes which are fired off to run in the
@@ -189,11 +216,13 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
         args: positional args for func
         kwargs: keyword args for func
 
-    Returns: Deferred which returns the result of func, but note that it does not
-        follow the synapse logcontext rules.
+    Returns:
+        Deferred which returns the result of func, or `None` if func raises.
+        Note that the returned Deferred does not follow the synapse logcontext
+        rules.
     """
 
-    async def run():
+    async def run() -> Optional[R]:
         with _bg_metrics_lock:
             count = _background_process_counts.get(desc, 0)
             _background_process_counts[desc] = count + 1
@@ -210,12 +239,13 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
                 else:
                     ctx = noop_context_manager()
                 with ctx:
-                    return await maybe_awaitable(func(*args, **kwargs))
+                    return await func(*args, **kwargs)
             except Exception:
                 logger.exception(
                     "Background process '%s' threw an exception",
                     desc,
                 )
+                return None
             finally:
                 _background_process_in_flight_count.labels(desc).dec()
 
@@ -225,19 +255,24 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
         return defer.ensureDeferred(run())
 
 
-def wrap_as_background_process(desc):
+F = TypeVar("F", bound=Callable[..., Awaitable[Optional[Any]]])
+
+
+def wrap_as_background_process(desc: str) -> Callable[[F], F]:
     """Decorator that wraps a function that gets called as a background
     process.
 
-    Equivalent of calling the function with `run_as_background_process`
+    Equivalent to calling the function with `run_as_background_process`
     """
 
-    def wrap_as_background_process_inner(func):
+    def wrap_as_background_process_inner(func: F) -> F:
         @wraps(func)
-        def wrap_as_background_process_inner_2(*args, **kwargs):
+        def wrap_as_background_process_inner_2(
+            *args: Any, **kwargs: Any
+        ) -> "defer.Deferred[Optional[R]]":
             return run_as_background_process(desc, func, *args, **kwargs)
 
-        return wrap_as_background_process_inner_2
+        return cast(F, wrap_as_background_process_inner_2)
 
     return wrap_as_background_process_inner
 
@@ -265,7 +300,7 @@ class BackgroundProcessLoggingContext(LoggingContext):
         super().__init__("%s-%s" % (name, instance_id))
         self._proc = _BackgroundProcess(name, self)
 
-    def start(self, rusage: "Optional[resource.struct_rusage]"):
+    def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
         """Log context has started running (again)."""
 
         super().start(rusage)
@@ -276,7 +311,12 @@ class BackgroundProcessLoggingContext(LoggingContext):
         with _bg_metrics_lock:
             _background_processes_active_since_last_scrape.add(self._proc)
 
-    def __exit__(self, type, value, traceback) -> None:
+    def __exit__(
+        self,
+        type: Optional[Type[BaseException]],
+        value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> None:
         """Log context has finished."""
 
         super().__exit__(type, value, traceback)
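
A minimal sketch of the decorator after this change (the description string and the wrapped coroutine are hypothetical): since `run_as_background_process` now awaits `func` directly instead of going through `maybe_awaitable`, the wrapped function must be a coroutine function, and the returned Deferred resolves to `None` if it raises.

    from synapse.metrics.background_process_metrics import wrap_as_background_process

    @wrap_as_background_process("example_cleanup")  # hypothetical description
    async def do_cleanup(batch_size: int) -> int:
        # ... periodic work would go here ...
        return batch_size

    # Returns a Deferred resolving to 50, or to None if the coroutine raised
    # (the exception is logged rather than propagated).
    deferred = do_cleanup(50)
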
diff --git a/synapse/metrics/jemalloc.py b/synapse/metrics/jemalloc.py
index 29ab6c0229..98ed9c0829 100644
--- a/synapse/metrics/jemalloc.py
+++ b/synapse/metrics/jemalloc.py
@@ -16,14 +16,16 @@ import ctypes
 import logging
 import os
 import re
-from typing import Optional
+from typing import Iterable, Optional
+
+from prometheus_client import Metric
 
 from synapse.metrics import REGISTRY, GaugeMetricFamily
 
 logger = logging.getLogger(__name__)
 
 
-def _setup_jemalloc_stats():
+def _setup_jemalloc_stats() -> None:
     """Checks to see if jemalloc is loaded, and hooks up a collector to record
     statistics exposed by jemalloc.
     """
@@ -135,7 +137,7 @@ def _setup_jemalloc_stats():
     class JemallocCollector:
         """Metrics for internal jemalloc stats."""
 
-        def collect(self):
+        def collect(self) -> Iterable[Metric]:
             _jemalloc_refresh_stats()
 
             g = GaugeMetricFamily(
@@ -185,7 +187,7 @@ def _setup_jemalloc_stats():
     logger.debug("Added jemalloc stats")
 
 
-def setup_jemalloc_stats():
+def setup_jemalloc_stats() -> None:
     """Try to setup jemalloc stats, if jemalloc is loaded."""
 
     try:
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index ff79bc3c11..6bfb4b8d1b 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -24,6 +24,7 @@ from typing import (
     List,
     Optional,
     Tuple,
+    TypeVar,
     Union,
 )
 
@@ -35,7 +36,44 @@ from twisted.web.resource import Resource
 
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase
-from synapse.events.presence_router import PresenceRouter
+from synapse.events.presence_router import (
+    GET_INTERESTED_USERS_CALLBACK,
+    GET_USERS_FOR_STATES_CALLBACK,
+    PresenceRouter,
+)
+from synapse.events.spamcheck import (
+    CHECK_EVENT_FOR_SPAM_CALLBACK,
+    CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK,
+    CHECK_REGISTRATION_FOR_SPAM_CALLBACK,
+    CHECK_USERNAME_FOR_SPAM_CALLBACK,
+    USER_MAY_CREATE_ROOM_ALIAS_CALLBACK,
+    USER_MAY_CREATE_ROOM_CALLBACK,
+    USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK,
+    USER_MAY_INVITE_CALLBACK,
+    USER_MAY_JOIN_ROOM_CALLBACK,
+    USER_MAY_PUBLISH_ROOM_CALLBACK,
+    USER_MAY_SEND_3PID_INVITE_CALLBACK,
+)
+from synapse.events.third_party_rules import (
+    CHECK_EVENT_ALLOWED_CALLBACK,
+    CHECK_THREEPID_CAN_BE_INVITED_CALLBACK,
+    CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK,
+    ON_CREATE_ROOM_CALLBACK,
+    ON_NEW_EVENT_CALLBACK,
+)
+from synapse.handlers.account_validity import (
+    IS_USER_EXPIRED_CALLBACK,
+    ON_LEGACY_ADMIN_REQUEST,
+    ON_LEGACY_RENEW_CALLBACK,
+    ON_LEGACY_SEND_MAIL_CALLBACK,
+    ON_USER_REGISTRATION_CALLBACK,
+)
+from synapse.handlers.auth import (
+    CHECK_3PID_AUTH_CALLBACK,
+    CHECK_AUTH_CALLBACK,
+    ON_LOGGED_OUT_CALLBACK,
+    AuthHandler,
+)
 from synapse.http.client import SimpleHttpClient
 from synapse.http.server import (
     DirectServeHtmlResource,
@@ -44,10 +82,19 @@ from synapse.http.server import (
 )
 from synapse.http.servlet import parse_json_object_from_request
 from synapse.http.site import SynapseRequest
-from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.logging.context import (
+    defer_to_thread,
+    make_deferred_yieldable,
+    run_in_background,
+)
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.rest.client.login import LoginResponse
 from synapse.storage import DataStore
+from synapse.storage.background_updates import (
+    DEFAULT_BATCH_SIZE_CALLBACK,
+    MIN_BATCH_SIZE_CALLBACK,
+    ON_UPDATE_CALLBACK,
+)
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.roommember import ProfileInfo
 from synapse.storage.state import StateFilter
@@ -67,6 +114,9 @@ if TYPE_CHECKING:
     from synapse.app.generic_worker import GenericWorkerSlavedStore
     from synapse.server import HomeServer
 
+
+T = TypeVar("T")
+
 """
 This package defines the 'stable' API which can be used by extension modules which
 are loaded into Synapse.
@@ -114,7 +164,7 @@ class ModuleApi:
     can register new users etc if necessary.
     """
 
-    def __init__(self, hs: "HomeServer", auth_handler):
+    def __init__(self, hs: "HomeServer", auth_handler: AuthHandler) -> None:
         self._hs = hs
 
         # TODO: Fix this type hint once the types for the data stores have been ironed
@@ -156,47 +206,139 @@ class ModuleApi:
     #################################################################################
     # The following methods should only be called during the module's initialisation.
 
-    @property
-    def register_spam_checker_callbacks(self):
+    def register_spam_checker_callbacks(
+        self,
+        check_event_for_spam: Optional[CHECK_EVENT_FOR_SPAM_CALLBACK] = None,
+        user_may_join_room: Optional[USER_MAY_JOIN_ROOM_CALLBACK] = None,
+        user_may_invite: Optional[USER_MAY_INVITE_CALLBACK] = None,
+        user_may_send_3pid_invite: Optional[USER_MAY_SEND_3PID_INVITE_CALLBACK] = None,
+        user_may_create_room: Optional[USER_MAY_CREATE_ROOM_CALLBACK] = None,
+        user_may_create_room_with_invites: Optional[
+            USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK
+        ] = None,
+        user_may_create_room_alias: Optional[
+            USER_MAY_CREATE_ROOM_ALIAS_CALLBACK
+        ] = None,
+        user_may_publish_room: Optional[USER_MAY_PUBLISH_ROOM_CALLBACK] = None,
+        check_username_for_spam: Optional[CHECK_USERNAME_FOR_SPAM_CALLBACK] = None,
+        check_registration_for_spam: Optional[
+            CHECK_REGISTRATION_FOR_SPAM_CALLBACK
+        ] = None,
+        check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
+    ) -> None:
         """Registers callbacks for spam checking capabilities.
 
         Added in Synapse v1.37.0.
         """
-        return self._spam_checker.register_callbacks
+        return self._spam_checker.register_callbacks(
+            check_event_for_spam=check_event_for_spam,
+            user_may_join_room=user_may_join_room,
+            user_may_invite=user_may_invite,
+            user_may_send_3pid_invite=user_may_send_3pid_invite,
+            user_may_create_room=user_may_create_room,
+            user_may_create_room_with_invites=user_may_create_room_with_invites,
+            user_may_create_room_alias=user_may_create_room_alias,
+            user_may_publish_room=user_may_publish_room,
+            check_username_for_spam=check_username_for_spam,
+            check_registration_for_spam=check_registration_for_spam,
+            check_media_file_for_spam=check_media_file_for_spam,
+        )
 
-    @property
-    def register_account_validity_callbacks(self):
+    def register_account_validity_callbacks(
+        self,
+        is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None,
+        on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None,
+        on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None,
+        on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None,
+        on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None,
+    ) -> None:
         """Registers callbacks for account validity capabilities.
 
         Added in Synapse v1.39.0.
         """
-        return self._account_validity_handler.register_account_validity_callbacks
+        return self._account_validity_handler.register_account_validity_callbacks(
+            is_user_expired=is_user_expired,
+            on_user_registration=on_user_registration,
+            on_legacy_send_mail=on_legacy_send_mail,
+            on_legacy_renew=on_legacy_renew,
+            on_legacy_admin_request=on_legacy_admin_request,
+        )
 
-    @property
-    def register_third_party_rules_callbacks(self):
+    def register_third_party_rules_callbacks(
+        self,
+        check_event_allowed: Optional[CHECK_EVENT_ALLOWED_CALLBACK] = None,
+        on_create_room: Optional[ON_CREATE_ROOM_CALLBACK] = None,
+        check_threepid_can_be_invited: Optional[
+            CHECK_THREEPID_CAN_BE_INVITED_CALLBACK
+        ] = None,
+        check_visibility_can_be_modified: Optional[
+            CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
+        ] = None,
+        on_new_event: Optional[ON_NEW_EVENT_CALLBACK] = None,
+    ) -> None:
         """Registers callbacks for third party event rules capabilities.
 
         Added in Synapse v1.39.0.
         """
-        return self._third_party_event_rules.register_third_party_rules_callbacks
+        return self._third_party_event_rules.register_third_party_rules_callbacks(
+            check_event_allowed=check_event_allowed,
+            on_create_room=on_create_room,
+            check_threepid_can_be_invited=check_threepid_can_be_invited,
+            check_visibility_can_be_modified=check_visibility_can_be_modified,
+            on_new_event=on_new_event,
+        )
 
-    @property
-    def register_presence_router_callbacks(self):
+    def register_presence_router_callbacks(
+        self,
+        get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None,
+        get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None,
+    ) -> None:
         """Registers callbacks for presence router capabilities.
 
         Added in Synapse v1.42.0.
         """
-        return self._presence_router.register_presence_router_callbacks
+        return self._presence_router.register_presence_router_callbacks(
+            get_users_for_states=get_users_for_states,
+            get_interested_users=get_interested_users,
+        )
 
-    @property
-    def register_password_auth_provider_callbacks(self):
+    def register_password_auth_provider_callbacks(
+        self,
+        check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None,
+        on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None,
+        auth_checkers: Optional[
+            Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
+        ] = None,
+    ) -> None:
         """Registers callbacks for password auth provider capabilities.
 
         Added in Synapse v1.46.0.
         """
-        return self._password_auth_provider.register_password_auth_provider_callbacks
+        return self._password_auth_provider.register_password_auth_provider_callbacks(
+            check_3pid_auth=check_3pid_auth,
+            on_logged_out=on_logged_out,
+            auth_checkers=auth_checkers,
+        )
+
+    def register_background_update_controller_callbacks(
+        self,
+        on_update: ON_UPDATE_CALLBACK,
+        default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
+        min_batch_size: Optional[MIN_BATCH_SIZE_CALLBACK] = None,
+    ) -> None:
+        """Registers background update controller callbacks.
 
-    def register_web_resource(self, path: str, resource: Resource):
+        Added in Synapse v1.49.0.
+        """
+
+        for db in self._hs.get_datastores().databases:
+            db.updates.register_update_controller_callbacks(
+                on_update=on_update,
+                default_batch_size=default_batch_size,
+                min_batch_size=min_batch_size,
+            )
+
+    def register_web_resource(self, path: str, resource: Resource) -> None:
         """Registers a web resource to be served at the given path.
 
         This function should be called during initialisation of the module.
@@ -216,7 +358,7 @@ class ModuleApi:
     # The following methods can be called by the module at any point in time.
 
     @property
-    def http_client(self):
+    def http_client(self) -> SimpleHttpClient:
         """Allows making outbound HTTP requests to remote resources.
 
         An instance of synapse.http.client.SimpleHttpClient
@@ -226,7 +368,7 @@ class ModuleApi:
         return self._http_client
 
     @property
-    def public_room_list_manager(self):
+    def public_room_list_manager(self) -> "PublicRoomListManager":
         """Allows adding to, removing from and checking the status of rooms in the
         public room list.
 
@@ -309,7 +451,7 @@ class ModuleApi:
         """
         return await self._store.is_server_admin(UserID.from_string(user_id))
 
-    def get_qualified_user_id(self, username):
+    def get_qualified_user_id(self, username: str) -> str:
         """Qualify a user id, if necessary
 
         Takes a user id provided by the user and adds the @ and :domain to
@@ -318,10 +460,10 @@ class ModuleApi:
         Added in Synapse v0.25.0.
 
         Args:
-            username (str): provided user id
+            username: provided user id
 
         Returns:
-            str: qualified @user:id
+            qualified @user:id
         """
         if username.startswith("@"):
             return username
@@ -357,22 +499,27 @@ class ModuleApi:
         """
         return await self._store.user_get_threepids(user_id)
 
-    def check_user_exists(self, user_id):
+    def check_user_exists(self, user_id: str) -> "defer.Deferred[Optional[str]]":
         """Check if user exists.
 
         Added in Synapse v0.25.0.
 
         Args:
-            user_id (str): Complete @user:id
+            user_id: Complete @user:id
 
         Returns:
-            Deferred[str|None]: Canonical (case-corrected) user_id, or None
+            Canonical (case-corrected) user_id, or None
                if the user is not registered.
         """
         return defer.ensureDeferred(self._auth_handler.check_user_exists(user_id))
 
     @defer.inlineCallbacks
-    def register(self, localpart, displayname=None, emails: Optional[List[str]] = None):
+    def register(
+        self,
+        localpart: str,
+        displayname: Optional[str] = None,
+        emails: Optional[List[str]] = None,
+    ) -> Generator["defer.Deferred[Any]", Any, Tuple[str, str]]:
         """Registers a new user with given localpart and optional displayname, emails.
 
         Also returns an access token for the new user.
@@ -384,12 +531,12 @@ class ModuleApi:
         Added in Synapse v0.25.0.
 
         Args:
-            localpart (str): The localpart of the new user.
-            displayname (str|None): The displayname of the new user.
-            emails (List[str]): Emails to bind to the new user.
+            localpart: The localpart of the new user.
+            displayname: The displayname of the new user.
+            emails: Emails to bind to the new user.
 
         Returns:
-            Deferred[tuple[str, str]]: a 2-tuple of (user_id, access_token)
+            a 2-tuple of (user_id, access_token)
         """
         logger.warning(
             "Using deprecated ModuleApi.register which creates a dummy user device."
@@ -399,23 +546,26 @@ class ModuleApi:
         return user_id, access_token
 
     def register_user(
-        self, localpart, displayname=None, emails: Optional[List[str]] = None
-    ):
+        self,
+        localpart: str,
+        displayname: Optional[str] = None,
+        emails: Optional[List[str]] = None,
+    ) -> "defer.Deferred[str]":
         """Registers a new user with given localpart and optional displayname, emails.
 
         Added in Synapse v1.2.0.
 
         Args:
-            localpart (str): The localpart of the new user.
-            displayname (str|None): The displayname of the new user.
-            emails (List[str]): Emails to bind to the new user.
+            localpart: The localpart of the new user.
+            displayname: The displayname of the new user.
+            emails: Emails to bind to the new user.
 
         Raises:
             SynapseError if there is an error performing the registration. Check the
                 'errcode' property for more information on the reason for failure
 
         Returns:
-            defer.Deferred[str]: user_id
+            user_id
         """
         return defer.ensureDeferred(
             self._hs.get_registration_handler().register_user(
@@ -425,20 +575,25 @@ class ModuleApi:
             )
         )
 
-    def register_device(self, user_id, device_id=None, initial_display_name=None):
+    def register_device(
+        self,
+        user_id: str,
+        device_id: Optional[str] = None,
+        initial_display_name: Optional[str] = None,
+    ) -> "defer.Deferred[Tuple[str, str, Optional[int], Optional[str]]]":
         """Register a device for a user and generate an access token.
 
         Added in Synapse v1.2.0.
 
         Args:
-            user_id (str): full canonical @user:id
-            device_id (str|None): The device ID to check, or None to generate
+            user_id: full canonical @user:id
+            device_id: The device ID to check, or None to generate
                 a new one.
-            initial_display_name (str|None): An optional display name for the
+            initial_display_name: An optional display name for the
                 device.
 
         Returns:
-            defer.Deferred[tuple[str, str]]: Tuple of device ID and access token
+            Tuple of device ID, access token, access token expiration time and refresh token
         """
         return defer.ensureDeferred(
             self._hs.get_registration_handler().register_device(
@@ -471,6 +626,7 @@ class ModuleApi:
         user_id: str,
         duration_in_ms: int = (2 * 60 * 1000),
         auth_provider_id: str = "",
+        auth_provider_session_id: Optional[str] = None,
     ) -> str:
         """Generate a login token suitable for m.login.token authentication
 
@@ -488,11 +644,14 @@ class ModuleApi:
         return self._hs.get_macaroon_generator().generate_short_term_login_token(
             user_id,
             auth_provider_id,
+            auth_provider_session_id,
             duration_in_ms,
         )
 
     @defer.inlineCallbacks
-    def invalidate_access_token(self, access_token):
+    def invalidate_access_token(
+        self, access_token: str
+    ) -> Generator["defer.Deferred[Any]", Any, None]:
         """Invalidate an access token for a user
 
         Added in Synapse v0.25.0.
@@ -524,14 +683,20 @@ class ModuleApi:
                 self._auth_handler.delete_access_token(access_token)
             )
 
-    def run_db_interaction(self, desc, func, *args, **kwargs):
+    def run_db_interaction(
+        self,
+        desc: str,
+        func: Callable[..., T],
+        *args: Any,
+        **kwargs: Any,
+    ) -> "defer.Deferred[T]":
         """Run a function with a database connection
 
         Added in Synapse v0.25.0.
 
         Args:
-            desc (str): description for the transaction, for metrics etc
-            func (func): function to be run. Passed a database cursor object
+            desc: description for the transaction, for metrics etc
+            func: function to be run. Passed a database cursor object
                 as well as *args and **kwargs
             *args: positional args to be passed to func
             **kwargs: named args to be passed to func
@@ -545,7 +710,7 @@ class ModuleApi:
 
     def complete_sso_login(
         self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
-    ):
+    ) -> None:
         """Complete a SSO login by redirecting the user to a page to confirm whether they
         want their access token sent to `client_redirect_url`, or redirect them to that
         URL with a token directly if the URL matches with one of the whitelisted clients.
@@ -575,7 +740,7 @@ class ModuleApi:
         client_redirect_url: str,
         new_user: bool = False,
         auth_provider_id: str = "<unknown>",
-    ):
+    ) -> None:
         """Complete a SSO login by redirecting the user to a page to confirm whether they
         want their access token sent to `client_redirect_url`, or redirect them to that
         URL with a token directly if the URL matches with one of the whitelisted clients.
@@ -814,11 +979,11 @@ class ModuleApi:
         self,
         f: Callable,
         msec: float,
-        *args,
+        *args: object,
         desc: Optional[str] = None,
         run_on_all_instances: bool = False,
-        **kwargs,
-    ):
+        **kwargs: object,
+    ) -> None:
         """Wraps a function as a background process and calls it repeatedly.
 
         NOTE: Will only run on the instance that is configured to run
@@ -859,13 +1024,18 @@ class ModuleApi:
                 f,
             )
 
+    async def sleep(self, seconds: float) -> None:
+        """Sleeps for the given number of seconds."""
+
+        await self._clock.sleep(seconds)
+
     async def send_mail(
         self,
         recipient: str,
         subject: str,
         html: str,
         text: str,
-    ):
+    ) -> None:
         """Send an email on behalf of the homeserver.
 
         Added in Synapse v1.39.0.
@@ -903,7 +1073,7 @@ class ModuleApi:
             A list containing the loaded templates, with the orders matching the one of
             the filenames parameter.
         """
-        return self._hs.config.read_templates(
+        return self._hs.config.server.read_templates(
             filenames,
             (td for td in (self.custom_template_dir, custom_template_directory) if td),
         )
@@ -1013,6 +1183,26 @@ class ModuleApi:
 
         return {key: state_events[event_id] for key, event_id in state_ids.items()}
 
+    async def defer_to_thread(
+        self,
+        f: Callable[..., T],
+        *args: Any,
+        **kwargs: Any,
+    ) -> T:
+        """Runs the given function in a separate thread from Synapse's thread pool.
+
+        Added in Synapse v1.49.0.
+
+        Args:
+            f: The function to run.
+            args: The function's arguments.
+            kwargs: The function's keyword arguments.
+
+        Returns:
+            The return value of the function once run in a thread.
+        """
+        return await defer_to_thread(self._hs.get_reactor(), f, *args, **kwargs)
+
 
 class PublicRoomListManager:
     """Contains methods for adding to, removing from and querying whether a room
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index cf5abdfbda..4f13c0418a 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -21,6 +21,8 @@ from twisted.internet.interfaces import IDelayedCall
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.push import Pusher, PusherConfig, PusherConfigException, ThrottleParams
 from synapse.push.mailer import Mailer
+from synapse.push.push_types import EmailReason
+from synapse.storage.databases.main.event_push_actions import EmailPushAction
 from synapse.util.threepids import validate_email
 
 if TYPE_CHECKING:
@@ -190,7 +192,7 @@ class EmailPusher(Pusher):
                 # we then consider all previously outstanding notifications
                 # to be delivered.
 
-                reason = {
+                reason: EmailReason = {
                     "room_id": push_action["room_id"],
                     "now": self.clock.time_msec(),
                     "received_at": received_at,
@@ -275,7 +277,7 @@ class EmailPusher(Pusher):
         return may_send_at
 
     async def sent_notif_update_throttle(
-        self, room_id: str, notified_push_action: dict
+        self, room_id: str, notified_push_action: EmailPushAction
     ) -> None:
         # We have sent a notification, so update the throttle accordingly.
         # If the event that triggered the notif happened more than
@@ -315,7 +317,9 @@ class EmailPusher(Pusher):
             self.pusher_id, room_id, self.throttle_params[room_id]
         )
 
-    async def send_notification(self, push_actions: List[dict], reason: dict) -> None:
+    async def send_notification(
+        self, push_actions: List[EmailPushAction], reason: EmailReason
+    ) -> None:
         logger.info("Sending notif email for user %r", self.user_id)
 
         await self.mailer.send_notification_mail(
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index dbf4ad7f97..3fa603ccb7 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -26,6 +26,7 @@ from synapse.events import EventBase
 from synapse.logging import opentracing
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.push import Pusher, PusherConfig, PusherConfigException
+from synapse.storage.databases.main.event_push_actions import HttpPushAction
 
 from . import push_rule_evaluator, push_tools
 
@@ -273,7 +274,7 @@ class HttpPusher(Pusher):
                     )
                     break
 
-    async def _process_one(self, push_action: dict) -> bool:
+    async def _process_one(self, push_action: HttpPushAction) -> bool:
         if "notify" not in push_action["actions"]:
             return True
 
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index ce299ba3da..ba4f866487 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -14,7 +14,7 @@
 
 import logging
 import urllib.parse
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, TypeVar
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, TypeVar
 
 import bleach
 import jinja2
@@ -28,6 +28,14 @@ from synapse.push.presentable_names import (
     descriptor_from_member_events,
     name_from_member_event,
 )
+from synapse.push.push_types import (
+    EmailReason,
+    MessageVars,
+    NotifVars,
+    RoomVars,
+    TemplateVars,
+)
+from synapse.storage.databases.main.event_push_actions import EmailPushAction
 from synapse.storage.state import StateFilter
 from synapse.types import StateMap, UserID
 from synapse.util.async_helpers import concurrently_execute
@@ -135,7 +143,7 @@ class Mailer:
             % urllib.parse.urlencode(params)
         )
 
-        template_vars = {"link": link}
+        template_vars: TemplateVars = {"link": link}
 
         await self.send_email(
             email_address,
@@ -165,7 +173,7 @@ class Mailer:
             % urllib.parse.urlencode(params)
         )
 
-        template_vars = {"link": link}
+        template_vars: TemplateVars = {"link": link}
 
         await self.send_email(
             email_address,
@@ -196,7 +204,7 @@ class Mailer:
             % urllib.parse.urlencode(params)
         )
 
-        template_vars = {"link": link}
+        template_vars: TemplateVars = {"link": link}
 
         await self.send_email(
             email_address,
@@ -210,8 +218,8 @@ class Mailer:
         app_id: str,
         user_id: str,
         email_address: str,
-        push_actions: Iterable[Dict[str, Any]],
-        reason: Dict[str, Any],
+        push_actions: Iterable[EmailPushAction],
+        reason: EmailReason,
     ) -> None:
         """
         Send email regarding a user's room notifications
@@ -230,7 +238,7 @@ class Mailer:
             [pa["event_id"] for pa in push_actions]
         )
 
-        notifs_by_room: Dict[str, List[Dict[str, Any]]] = {}
+        notifs_by_room: Dict[str, List[EmailPushAction]] = {}
         for pa in push_actions:
             notifs_by_room.setdefault(pa["room_id"], []).append(pa)
 
@@ -258,7 +266,7 @@ class Mailer:
         # actually sort our so-called rooms_in_order list, most recent room first
         rooms_in_order.sort(key=lambda r: -(notifs_by_room[r][-1]["received_ts"] or 0))
 
-        rooms: List[Dict[str, Any]] = []
+        rooms: List[RoomVars] = []
 
         for r in rooms_in_order:
             roomvars = await self._get_room_vars(
@@ -289,7 +297,7 @@ class Mailer:
                 notifs_by_room, state_by_room, notif_events, reason
             )
 
-        template_vars = {
+        template_vars: TemplateVars = {
             "user_display_name": user_display_name,
             "unsubscribe_link": self._make_unsubscribe_link(
                 user_id, app_id, email_address
@@ -302,10 +310,10 @@ class Mailer:
         await self.send_email(email_address, summary_text, template_vars)
 
     async def send_email(
-        self, email_address: str, subject: str, extra_template_vars: Dict[str, Any]
+        self, email_address: str, subject: str, extra_template_vars: TemplateVars
     ) -> None:
         """Send an email with the given information and template text"""
-        template_vars = {
+        template_vars: TemplateVars = {
             "app_name": self.app_name,
             "server_name": self.hs.config.server.server_name,
         }
@@ -327,10 +335,10 @@ class Mailer:
         self,
         room_id: str,
         user_id: str,
-        notifs: Iterable[Dict[str, Any]],
+        notifs: Iterable[EmailPushAction],
         notif_events: Dict[str, EventBase],
         room_state_ids: StateMap[str],
-    ) -> Dict[str, Any]:
+    ) -> RoomVars:
         """
         Generate the variables for notifications on a per-room basis.
 
@@ -356,7 +364,7 @@ class Mailer:
 
         room_name = await calculate_room_name(self.store, room_state_ids, user_id)
 
-        room_vars: Dict[str, Any] = {
+        room_vars: RoomVars = {
             "title": room_name,
             "hash": string_ordinal_total(room_id),  # See sender avatar hash
             "notifs": [],
@@ -417,11 +425,11 @@ class Mailer:
 
     async def _get_notif_vars(
         self,
-        notif: Dict[str, Any],
+        notif: EmailPushAction,
         user_id: str,
         notif_event: EventBase,
         room_state_ids: StateMap[str],
-    ) -> Dict[str, Any]:
+    ) -> NotifVars:
         """
         Generate the variables for a single notification.
 
@@ -442,7 +450,7 @@ class Mailer:
             after_limit=CONTEXT_AFTER,
         )
 
-        ret = {
+        ret: NotifVars = {
             "link": self._make_notif_link(notif),
             "ts": notif["received_ts"],
             "messages": [],
@@ -461,8 +469,8 @@ class Mailer:
         return ret
 
     async def _get_message_vars(
-        self, notif: Dict[str, Any], event: EventBase, room_state_ids: StateMap[str]
-    ) -> Optional[Dict[str, Any]]:
+        self, notif: EmailPushAction, event: EventBase, room_state_ids: StateMap[str]
+    ) -> Optional[MessageVars]:
         """
         Generate the variables for a single event, if possible.
 
@@ -494,7 +502,9 @@ class Mailer:
 
         if sender_state_event:
             sender_name = name_from_member_event(sender_state_event)
-            sender_avatar_url = sender_state_event.content.get("avatar_url")
+            sender_avatar_url: Optional[str] = sender_state_event.content.get(
+                "avatar_url"
+            )
         else:
             # No state could be found, fallback to the MXID.
             sender_name = event.sender
@@ -504,7 +514,7 @@ class Mailer:
         # sender_hash % the number of default images to choose from
         sender_hash = string_ordinal_total(event.sender)
 
-        ret = {
+        ret: MessageVars = {
             "event_type": event.type,
             "is_historical": event.event_id != notif["event_id"],
             "id": event.event_id,
@@ -519,6 +529,8 @@ class Mailer:
             return ret
 
         msgtype = event.content.get("msgtype")
+        if not isinstance(msgtype, str):
+            msgtype = None
 
         ret["msgtype"] = msgtype
 
@@ -533,7 +545,7 @@ class Mailer:
         return ret
 
     def _add_text_message_vars(
-        self, messagevars: Dict[str, Any], event: EventBase
+        self, messagevars: MessageVars, event: EventBase
     ) -> None:
         """
         Potentially add a sanitised message body to the message variables.
@@ -543,8 +555,8 @@ class Mailer:
             event: The event under consideration.
         """
         msgformat = event.content.get("format")
-
-        messagevars["format"] = msgformat
+        if not isinstance(msgformat, str):
+            msgformat = None
 
         formatted_body = event.content.get("formatted_body")
         body = event.content.get("body")
@@ -555,7 +567,7 @@ class Mailer:
             messagevars["body_text_html"] = safe_text(body)
 
     def _add_image_message_vars(
-        self, messagevars: Dict[str, Any], event: EventBase
+        self, messagevars: MessageVars, event: EventBase
     ) -> None:
         """
         Potentially add an image URL to the message variables.
@@ -570,7 +582,7 @@ class Mailer:
     async def _make_summary_text_single_room(
         self,
         room_id: str,
-        notifs: List[Dict[str, Any]],
+        notifs: List[EmailPushAction],
         room_state_ids: StateMap[str],
         notif_events: Dict[str, EventBase],
         user_id: str,
@@ -685,10 +697,10 @@ class Mailer:
 
     async def _make_summary_text(
         self,
-        notifs_by_room: Dict[str, List[Dict[str, Any]]],
+        notifs_by_room: Dict[str, List[EmailPushAction]],
         room_state_ids: Dict[str, StateMap[str]],
         notif_events: Dict[str, EventBase],
-        reason: Dict[str, Any],
+        reason: EmailReason,
     ) -> str:
         """
         Make a summary text for the email when multiple rooms have notifications.
@@ -718,7 +730,7 @@ class Mailer:
     async def _make_summary_text_from_member_events(
         self,
         room_id: str,
-        notifs: List[Dict[str, Any]],
+        notifs: List[EmailPushAction],
         room_state_ids: StateMap[str],
         notif_events: Dict[str, EventBase],
     ) -> str:
@@ -805,7 +817,7 @@ class Mailer:
             base_url = "https://matrix.to/#"
         return "%s/%s" % (base_url, room_id)
 
-    def _make_notif_link(self, notif: Dict[str, str]) -> str:
+    def _make_notif_link(self, notif: EmailPushAction) -> str:
         """
         Generate a link to open an event in the web client.
 
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 7f68092ec5..659a53805d 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -17,9 +17,10 @@ import logging
 import re
 from typing import Any, Dict, List, Optional, Pattern, Tuple, Union
 
+from matrix_common.regex import glob_to_regex, to_word_pattern
+
 from synapse.events import EventBase
 from synapse.types import JsonDict, UserID
-from synapse.util import glob_to_regex, re_word_boundary
 from synapse.util.caches.lrucache import LruCache
 
 logger = logging.getLogger(__name__)
@@ -184,7 +185,7 @@ class PushRuleEvaluatorForEvent:
         r = regex_cache.get((display_name, False, True), None)
         if not r:
             r1 = re.escape(display_name)
-            r1 = re_word_boundary(r1)
+            r1 = to_word_pattern(r1)
             r = re.compile(r1, flags=re.IGNORECASE)
             regex_cache[(display_name, False, True)] = r
 
@@ -213,7 +214,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
     try:
         r = regex_cache.get((glob, True, word_boundary), None)
         if not r:
-            r = glob_to_regex(glob, word_boundary)
+            r = glob_to_regex(glob, word_boundary=word_boundary)
             regex_cache[(glob, True, word_boundary)] = r
         return bool(r.search(value))
     except re.error:
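
The regex helpers now come from the matrix-common package rather than synapse.util; a minimal sketch of calling them (the pattern and input below are made up), noting that `word_boundary` is passed by keyword.

    import re

    from matrix_common.regex import glob_to_regex, to_word_pattern

    # Compile a glob into a regex that only matches at word boundaries.
    r = glob_to_regex("alice*", word_boundary=True)
    matched = bool(r.search("ping alice_bot now"))

    # Wrap an escaped display name so it only matches as a whole word.
    display_name_re = re.compile(to_word_pattern(re.escape("Bob")), flags=re.IGNORECASE)
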
diff --git a/synapse/push/push_types.py b/synapse/push/push_types.py
new file mode 100644
index 0000000000..8d16ab62ce
--- /dev/null
+++ b/synapse/push/push_types.py
@@ -0,0 +1,136 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import List, Optional
+
+from typing_extensions import TypedDict
+
+
+class EmailReason(TypedDict, total=False):
+    """
+    Information on the event that triggered the email to be sent
+
+    room_id: the ID of the room the event was sent in
+    now: timestamp in ms when the email is being sent out
+    room_name: a human-readable name for the room the event was sent in
+    received_at: the time in milliseconds at which the event was received
+    delay_before_mail_ms: the amount of time in milliseconds Synapse always waits
+            before ever emailing about a notification (to give the user a chance to
+            respond to another push notification or to notice the event in an open window)
+    last_sent_ts: the time in milliseconds at which a notification was last sent
+            for an event in this room
+    throttle_ms: the minimum amount of time in milliseconds that must elapse
+            between two notifications being sent for this room
+    """
+
+    room_id: str
+    now: int
+    room_name: Optional[str]
+    received_at: int
+    delay_before_mail_ms: int
+    last_sent_ts: int
+    throttle_ms: int
+
+
+class MessageVars(TypedDict, total=False):
+    """
+    Details about a specific message to include in a notification
+
+    event_type: the type of the event
+    is_historical: a boolean, which is `False` if the message is the one
+                that triggered the notification, `True` otherwise
+    id: the ID of the event
+    ts: the time in milliseconds at which the event was sent
+    sender_name: the display name for the event's sender
+    sender_avatar_url: the avatar URL (as a `mxc://` URL) for the event's
+                sender
+    sender_hash: a hash of the user ID of the sender
+    msgtype: the type of the message
+    body_text_html: html representation of the message
+    body_text_plain: plaintext representation of the message
+    image_url: mxc url of an image, when "msgtype" is "m.image"
+    """
+
+    event_type: str
+    is_historical: bool
+    id: str
+    ts: int
+    sender_name: str
+    sender_avatar_url: Optional[str]
+    sender_hash: int
+    msgtype: Optional[str]
+    body_text_html: str
+    body_text_plain: str
+    image_url: str
+
+
+class NotifVars(TypedDict):
+    """
+    Details about an event we are about to include in a notification
+
+    link: a `matrix.to` link to the event
+    ts: the time in milliseconds at which the event was received
+    messages: a list of messages containing one message before the event, the
+              message in the event, and one message after the event.
+    """
+
+    link: str
+    ts: Optional[int]
+    messages: List[MessageVars]
+
+
+class RoomVars(TypedDict):
+    """
+    Represents a room containing events to include in the email.
+
+    title: a human-readable name for the room
+    hash: a hash of the ID of the room
+    invite: a boolean, which is `True` if the room is an invite the user hasn't
+        accepted yet, `False` otherwise
+    notifs: a list of events, or an empty list if `invite` is `True`.
+    link: a `matrix.to` link to the room
+    avatar_url: URL to the room's avatar
+    """
+
+    title: Optional[str]
+    hash: int
+    invite: bool
+    notifs: List[NotifVars]
+    link: str
+    avatar_url: Optional[str]
+
+
+class TemplateVars(TypedDict, total=False):
+    """
+    Generic structure for passing to the email sender; it can hold all the fields used in email templates.
+
+    app_name: name of the app/service this homeserver is associated with
+    server_name: name of our own homeserver
+    link: a link to include into the email to be sent
+    user_display_name: the display name for the user receiving the notification
+    unsubscribe_link: the link users can click to unsubscribe from email notifications
+    summary_text: a summary of the notification(s). The text used can be customised
+              by configuring the various settings in the `email.subjects` section of the
+              configuration file.
+    rooms: a list of rooms containing events to include in the email
+    reason: information on the event that triggered the email to be sent
+    """
+
+    app_name: str
+    server_name: str
+    link: str
+    user_display_name: str
+    unsubscribe_link: str
+    summary_text: str
+    rooms: List[RoomVars]
+    reason: EmailReason
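
A small sketch (all values hypothetical) of building these typed template variables: `EmailReason` and `TemplateVars` are `total=False`, so they can be filled in incrementally, while `NotifVars` and `RoomVars` require all of their keys.

    from synapse.push.push_types import EmailReason, NotifVars, RoomVars, TemplateVars

    reason: EmailReason = {"room_id": "!abc:example.com", "now": 1_600_000_000_000}

    notif: NotifVars = {
        "link": "https://matrix.to/#/!abc:example.com/$event",
        "ts": 1_600_000_000_000,
        "messages": [],
    }

    room: RoomVars = {
        "title": "Example room",
        "hash": 42,
        "invite": False,
        "notifs": [notif],
        "link": "https://matrix.to/#/!abc:example.com",
        "avatar_url": None,
    }

    template_vars: TemplateVars = {"rooms": [room], "reason": reason}
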
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 154e5b7028..386debd7db 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -86,7 +86,8 @@ REQUIREMENTS = [
     # We enforce that we have a `cryptography` version that bundles an `openssl`
     # with the latest security patches.
     "cryptography>=3.4.7",
-    "ijson>=3.0",
+    "ijson>=3.1",
+    "matrix-common==1.0.0",
 ]
 
 CONDITIONAL_REQUIREMENTS = {
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index 0db419ea57..daacc34cea 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -46,6 +46,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
         is_guest,
         is_appservice_ghost,
         should_issue_refresh_token,
+        auth_provider_id,
+        auth_provider_session_id,
     ):
         """
         Args:
@@ -63,6 +65,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
             "is_guest": is_guest,
             "is_appservice_ghost": is_appservice_ghost,
             "should_issue_refresh_token": should_issue_refresh_token,
+            "auth_provider_id": auth_provider_id,
+            "auth_provider_session_id": auth_provider_session_id,
         }
 
     async def _handle_request(self, request, user_id):
@@ -73,6 +77,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
         is_guest = content["is_guest"]
         is_appservice_ghost = content["is_appservice_ghost"]
         should_issue_refresh_token = content["should_issue_refresh_token"]
+        auth_provider_id = content["auth_provider_id"]
+        auth_provider_session_id = content["auth_provider_session_id"]
 
         res = await self.registration_handler.register_device_inner(
             user_id,
@@ -81,6 +87,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
             is_guest,
             is_appservice_ghost=is_appservice_ghost,
             should_issue_refresh_token=should_issue_refresh_token,
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
         )
 
         return 200, res
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index 8c1bf9227a..fa132d10b4 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -14,10 +14,18 @@
 from typing import List, Optional, Tuple
 
 from synapse.storage.database import LoggingDatabaseConnection
-from synapse.storage.util.id_generators import _load_current_id
+from synapse.storage.util.id_generators import AbstractStreamIdTracker, _load_current_id
 
 
-class SlavedIdTracker:
+class SlavedIdTracker(AbstractStreamIdTracker):
+    """Tracks the "current" stream ID of a stream with a single writer.
+
+    See `AbstractStreamIdTracker` for more details.
+
+    Note that this class does not work correctly when there are multiple
+    writers.
+    """
+
     def __init__(
         self,
         db_conn: LoggingDatabaseConnection,
@@ -36,17 +44,7 @@ class SlavedIdTracker:
         self._current = (max if self.step > 0 else min)(self._current, new_id)
 
     def get_current_token(self) -> int:
-        """
-
-        Returns:
-            int
-        """
         return self._current
 
     def get_current_token_for_writer(self, instance_name: str) -> int:
-        """Returns the position of the given writer.
-
-        For streams with single writers this is equivalent to
-        `get_current_token`.
-        """
         return self.get_current_token()
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 4d5f862862..7541e21de9 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import PushRulesStream
 from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
 
@@ -25,9 +24,6 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
         return self._push_rules_stream_id_gen.get_current_token()
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        # We assert this for the benefit of mypy
-        assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker)
-
         if stream_name == PushRulesStream.NAME:
             self._push_rules_stream_id_gen.advance(instance_name, token)
             for row in rows:
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index a030e9299e..a390cfcb74 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import heapq
 from collections.abc import Iterable
-from typing import TYPE_CHECKING, List, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Optional, Tuple, Type
 
 import attr
 
@@ -157,7 +157,7 @@ class EventsStream(Stream):
 
         # now we fetch up to that many rows from the events table
 
-        event_rows: List[Tuple] = await self._store.get_all_new_forward_event_rows(
+        event_rows = await self._store.get_all_new_forward_event_rows(
             instance_name, from_token, current_token, target_row_count
         )
 
@@ -191,7 +191,7 @@ class EventsStream(Stream):
         # finally, fetch the ex-outliers rows. We assume there are few enough of these
         # not to bother with the limit.
 
-        ex_outliers_rows: List[Tuple] = await self._store.get_ex_outlier_stream_rows(
+        ex_outliers_rows = await self._store.get_ex_outlier_stream_rows(
             instance_name, from_token, upper_limit
         )
 
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index d78fe406c4..c499afd4be 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -17,6 +17,7 @@
 
 import logging
 import platform
+from http import HTTPStatus
 from typing import TYPE_CHECKING, Optional, Tuple
 
 import synapse
@@ -28,6 +29,7 @@ from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
 from synapse.rest.admin.background_updates import (
     BackgroundUpdateEnabledRestServlet,
     BackgroundUpdateRestServlet,
+    BackgroundUpdateStartJobRestServlet,
 )
 from synapse.rest.admin.devices import (
     DeleteDevicesRestServlet,
@@ -38,6 +40,10 @@ from synapse.rest.admin.event_reports import (
     EventReportDetailRestServlet,
     EventReportsRestServlet,
 )
+from synapse.rest.admin.federation import (
+    DestinationsRestServlet,
+    ListDestinationsRestServlet,
+)
 from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
 from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
 from synapse.rest.admin.registration_tokens import (
@@ -46,6 +52,7 @@ from synapse.rest.admin.registration_tokens import (
     RegistrationTokenRestServlet,
 )
 from synapse.rest.admin.rooms import (
+    BlockRoomRestServlet,
     DeleteRoomStatusByDeleteIdRestServlet,
     DeleteRoomStatusByRoomIdRestServlet,
     ForwardExtremitiesRestServlet,
@@ -96,7 +103,7 @@ class VersionServlet(RestServlet):
         }
 
     def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        return 200, self.res
+        return HTTPStatus.OK, self.res
 
 
 class PurgeHistoryRestServlet(RestServlet):
@@ -128,7 +135,7 @@ class PurgeHistoryRestServlet(RestServlet):
             event = await self.store.get_event(event_id)
 
             if event.room_id != room_id:
-                raise SynapseError(400, "Event is for wrong room.")
+                raise SynapseError(HTTPStatus.BAD_REQUEST, "Event is for wrong room.")
 
             # RoomStreamToken expects [int] not Optional[int]
             assert event.internal_metadata.stream_ordering is not None
@@ -142,7 +149,9 @@ class PurgeHistoryRestServlet(RestServlet):
             ts = body["purge_up_to_ts"]
             if not isinstance(ts, int):
                 raise SynapseError(
-                    400, "purge_up_to_ts must be an int", errcode=Codes.BAD_JSON
+                    HTTPStatus.BAD_REQUEST,
+                    "purge_up_to_ts must be an int",
+                    errcode=Codes.BAD_JSON,
                 )
 
             stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts)
@@ -158,7 +167,9 @@ class PurgeHistoryRestServlet(RestServlet):
                     stream_ordering,
                 )
                 raise SynapseError(
-                    404, "there is no event to be purged", errcode=Codes.NOT_FOUND
+                    HTTPStatus.NOT_FOUND,
+                    "there is no event to be purged",
+                    errcode=Codes.NOT_FOUND,
                 )
             (stream, topo, _event_id) = r
             token = "t%d-%d" % (topo, stream)
@@ -171,7 +182,7 @@ class PurgeHistoryRestServlet(RestServlet):
             )
         else:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "must specify purge_up_to_event_id or purge_up_to_ts",
                 errcode=Codes.BAD_JSON,
             )
@@ -180,7 +191,7 @@ class PurgeHistoryRestServlet(RestServlet):
             room_id, token, delete_local_events=delete_local_events
         )
 
-        return 200, {"purge_id": purge_id}
+        return HTTPStatus.OK, {"purge_id": purge_id}
 
 
 class PurgeHistoryStatusRestServlet(RestServlet):
@@ -199,7 +210,7 @@ class PurgeHistoryStatusRestServlet(RestServlet):
         if purge_status is None:
             raise NotFoundError("purge id '%s' not found" % purge_id)
 
-        return 200, purge_status.asdict()
+        return HTTPStatus.OK, purge_status.asdict()
 
 
 ########################################################################################
@@ -223,6 +234,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     Register all the admin servlets.
     """
     register_servlets_for_client_rest_resource(hs, http_server)
+    BlockRoomRestServlet(hs).register(http_server)
     ListRoomRestServlet(hs).register(http_server)
     RoomStateRestServlet(hs).register(http_server)
     RoomRestServlet(hs).register(http_server)
@@ -253,12 +265,15 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     ListRegistrationTokensRestServlet(hs).register(http_server)
     NewRegistrationTokenRestServlet(hs).register(http_server)
     RegistrationTokenRestServlet(hs).register(http_server)
+    DestinationsRestServlet(hs).register(http_server)
+    ListDestinationsRestServlet(hs).register(http_server)
 
     # Some servlets only get registered for the main process.
     if hs.config.worker.worker_app is None:
         SendServerNoticeServlet(hs).register(http_server)
         BackgroundUpdateEnabledRestServlet(hs).register(http_server)
         BackgroundUpdateRestServlet(hs).register(http_server)
+        BackgroundUpdateStartJobRestServlet(hs).register(http_server)
 
 
 def register_servlets_for_client_rest_resource(
diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py
index d9a2f6ca15..399b205aaf 100644
--- a/synapse/rest/admin/_base.py
+++ b/synapse/rest/admin/_base.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import re
+from http import HTTPStatus
 from typing import Iterable, Pattern
 
 from synapse.api.auth import Auth
@@ -62,4 +63,4 @@ async def assert_user_is_admin(auth: Auth, user_id: UserID) -> None:
     """
     is_admin = await auth.is_server_admin(user_id)
     if not is_admin:
-        raise AuthError(403, "You are not a server admin")
+        raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")
diff --git a/synapse/rest/admin/background_updates.py b/synapse/rest/admin/background_updates.py
index 0d0183bf20..479672d4d5 100644
--- a/synapse/rest/admin/background_updates.py
+++ b/synapse/rest/admin/background_updates.py
@@ -12,10 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from http import HTTPStatus
 from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.errors import SynapseError
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.servlet import (
+    RestServlet,
+    assert_params_in_dict,
+    parse_json_object_from_request,
+)
 from synapse.http.site import SynapseRequest
 from synapse.rest.admin._base import admin_patterns, assert_user_is_admin
 from synapse.types import JsonDict
@@ -29,37 +34,36 @@ logger = logging.getLogger(__name__)
 class BackgroundUpdateEnabledRestServlet(RestServlet):
     """Allows temporarily disabling background updates"""
 
-    PATTERNS = admin_patterns("/background_updates/enabled")
+    PATTERNS = admin_patterns("/background_updates/enabled$")
 
     def __init__(self, hs: "HomeServer"):
-        self.group_server = hs.get_groups_server_handler()
-        self.is_mine_id = hs.is_mine_id
-        self.auth = hs.get_auth()
-
-        self.data_stores = hs.get_datastores()
+        self._auth = hs.get_auth()
+        self._data_stores = hs.get_datastores()
 
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        requester = await self._auth.get_user_by_req(request)
+        await assert_user_is_admin(self._auth, requester.user)
 
         # We need to check that all configured databases have updates enabled.
         # (They *should* all be in sync.)
-        enabled = all(db.updates.enabled for db in self.data_stores.databases)
+        enabled = all(db.updates.enabled for db in self._data_stores.databases)
 
-        return 200, {"enabled": enabled}
+        return HTTPStatus.OK, {"enabled": enabled}
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        requester = await self._auth.get_user_by_req(request)
+        await assert_user_is_admin(self._auth, requester.user)
 
         body = parse_json_object_from_request(request)
 
         enabled = body.get("enabled", True)
 
         if not isinstance(enabled, bool):
-            raise SynapseError(400, "'enabled' parameter must be a boolean")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "'enabled' parameter must be a boolean"
+            )
 
-        for db in self.data_stores.databases:
+        for db in self._data_stores.databases:
             db.updates.enabled = enabled
 
             # If we're re-enabling them ensure that we start the background
@@ -67,32 +71,29 @@ class BackgroundUpdateEnabledRestServlet(RestServlet):
             if enabled:
                 db.updates.start_doing_background_updates()
 
-        return 200, {"enabled": enabled}
+        return HTTPStatus.OK, {"enabled": enabled}
 
 
 class BackgroundUpdateRestServlet(RestServlet):
     """Fetch information about background updates"""
 
-    PATTERNS = admin_patterns("/background_updates/status")
+    PATTERNS = admin_patterns("/background_updates/status$")
 
     def __init__(self, hs: "HomeServer"):
-        self.group_server = hs.get_groups_server_handler()
-        self.is_mine_id = hs.is_mine_id
-        self.auth = hs.get_auth()
-
-        self.data_stores = hs.get_datastores()
+        self._auth = hs.get_auth()
+        self._data_stores = hs.get_datastores()
 
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        requester = await self._auth.get_user_by_req(request)
+        await assert_user_is_admin(self._auth, requester.user)
 
         # We need to check that all configured databases have updates enabled.
         # (They *should* all be in sync.)
-        enabled = all(db.updates.enabled for db in self.data_stores.databases)
+        enabled = all(db.updates.enabled for db in self._data_stores.databases)
 
         current_updates = {}
 
-        for db in self.data_stores.databases:
+        for db in self._data_stores.databases:
             update = db.updates.get_current_update()
             if not update:
                 continue
@@ -104,4 +105,72 @@ class BackgroundUpdateRestServlet(RestServlet):
                 "average_items_per_ms": update.average_items_per_ms(),
             }
 
-        return 200, {"enabled": enabled, "current_updates": current_updates}
+        return HTTPStatus.OK, {"enabled": enabled, "current_updates": current_updates}
+
+
+class BackgroundUpdateStartJobRestServlet(RestServlet):
+    """Allows to start specific background updates"""
+
+    PATTERNS = admin_patterns("/background_updates/start_job")
+
+    def __init__(self, hs: "HomeServer"):
+        self._auth = hs.get_auth()
+        self._store = hs.get_datastore()
+
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        requester = await self._auth.get_user_by_req(request)
+        await assert_user_is_admin(self._auth, requester.user)
+
+        body = parse_json_object_from_request(request)
+        assert_params_in_dict(body, ["job_name"])
+
+        job_name = body["job_name"]
+
+        if job_name == "populate_stats_process_rooms":
+            jobs = [
+                {
+                    "update_name": "populate_stats_process_rooms",
+                    "progress_json": "{}",
+                },
+            ]
+        elif job_name == "regenerate_directory":
+            jobs = [
+                {
+                    "update_name": "populate_user_directory_createtables",
+                    "progress_json": "{}",
+                    "depends_on": "",
+                },
+                {
+                    "update_name": "populate_user_directory_process_rooms",
+                    "progress_json": "{}",
+                    "depends_on": "populate_user_directory_createtables",
+                },
+                {
+                    "update_name": "populate_user_directory_process_users",
+                    "progress_json": "{}",
+                    "depends_on": "populate_user_directory_process_rooms",
+                },
+                {
+                    "update_name": "populate_user_directory_cleanup",
+                    "progress_json": "{}",
+                    "depends_on": "populate_user_directory_process_users",
+                },
+            ]
+        else:
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid job_name")
+
+        try:
+            await self._store.db_pool.simple_insert_many(
+                table="background_updates",
+                values=jobs,
+                desc=f"admin_api_run_{job_name}",
+            )
+        except self._store.db_pool.engine.module.IntegrityError:
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST,
+                "Job %s is already in queue of background updates." % (job_name,),
+            )
+
+        self._store.db_pool.updates.start_doing_background_updates()
+
+        return HTTPStatus.OK, {}
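
As a usage illustration, the new job-runner endpoint is served under the admin API prefix and accepts one of the two `job_name` values handled above. A minimal sketch with the `requests` library (hostname and access token are placeholders):

    import requests

    resp = requests.post(
        "https://homeserver.example.org/_synapse/admin/v1/background_updates/start_job",
        headers={"Authorization": "Bearer <admin_access_token>"},
        json={"job_name": "regenerate_directory"},
    )
    resp.raise_for_status()  # the servlet replies 200 with an empty JSON object
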
diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py
index 80fbf32f17..2e5a6600d3 100644
--- a/synapse/rest/admin/devices.py
+++ b/synapse/rest/admin/devices.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from http import HTTPStatus
 from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.errors import NotFoundError, SynapseError
@@ -53,7 +54,7 @@ class DeviceRestServlet(RestServlet):
 
         target_user = UserID.from_string(user_id)
         if not self.hs.is_mine(target_user):
-            raise SynapseError(400, "Can only lookup local users")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
 
         u = await self.store.get_user_by_id(target_user.to_string())
         if u is None:
@@ -62,7 +63,7 @@ class DeviceRestServlet(RestServlet):
         device = await self.device_handler.get_device(
             target_user.to_string(), device_id
         )
-        return 200, device
+        return HTTPStatus.OK, device
 
     async def on_DELETE(
         self, request: SynapseRequest, user_id: str, device_id: str
@@ -71,14 +72,14 @@ class DeviceRestServlet(RestServlet):
 
         target_user = UserID.from_string(user_id)
         if not self.hs.is_mine(target_user):
-            raise SynapseError(400, "Can only lookup local users")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
 
         u = await self.store.get_user_by_id(target_user.to_string())
         if u is None:
             raise NotFoundError("Unknown user")
 
         await self.device_handler.delete_device(target_user.to_string(), device_id)
-        return 200, {}
+        return HTTPStatus.OK, {}
 
     async def on_PUT(
         self, request: SynapseRequest, user_id: str, device_id: str
@@ -87,7 +88,7 @@ class DeviceRestServlet(RestServlet):
 
         target_user = UserID.from_string(user_id)
         if not self.hs.is_mine(target_user):
-            raise SynapseError(400, "Can only lookup local users")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
 
         u = await self.store.get_user_by_id(target_user.to_string())
         if u is None:
@@ -97,7 +98,7 @@ class DeviceRestServlet(RestServlet):
         await self.device_handler.update_device(
             target_user.to_string(), device_id, body
         )
-        return 200, {}
+        return HTTPStatus.OK, {}
 
 
 class DevicesRestServlet(RestServlet):
@@ -124,14 +125,14 @@ class DevicesRestServlet(RestServlet):
 
         target_user = UserID.from_string(user_id)
         if not self.hs.is_mine(target_user):
-            raise SynapseError(400, "Can only lookup local users")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
 
         u = await self.store.get_user_by_id(target_user.to_string())
         if u is None:
             raise NotFoundError("Unknown user")
 
         devices = await self.device_handler.get_devices_by_user(target_user.to_string())
-        return 200, {"devices": devices, "total": len(devices)}
+        return HTTPStatus.OK, {"devices": devices, "total": len(devices)}
 
 
 class DeleteDevicesRestServlet(RestServlet):
@@ -155,7 +156,7 @@ class DeleteDevicesRestServlet(RestServlet):
 
         target_user = UserID.from_string(user_id)
         if not self.hs.is_mine(target_user):
-            raise SynapseError(400, "Can only lookup local users")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
 
         u = await self.store.get_user_by_id(target_user.to_string())
         if u is None:
@@ -167,4 +168,4 @@ class DeleteDevicesRestServlet(RestServlet):
         await self.device_handler.delete_devices(
             target_user.to_string(), body["devices"]
         )
-        return 200, {}
+        return HTTPStatus.OK, {}
diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py
index bbfcaf723b..5ee8b11110 100644
--- a/synapse/rest/admin/event_reports.py
+++ b/synapse/rest/admin/event_reports.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+from http import HTTPStatus
 from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.errors import Codes, NotFoundError, SynapseError
@@ -66,21 +67,23 @@ class EventReportsRestServlet(RestServlet):
 
         if start < 0:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "The start parameter must be a positive integer.",
                 errcode=Codes.INVALID_PARAM,
             )
 
         if limit < 0:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "The limit parameter must be a positive integer.",
                 errcode=Codes.INVALID_PARAM,
             )
 
         if direction not in ("f", "b"):
             raise SynapseError(
-                400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM
+                HTTPStatus.BAD_REQUEST,
+                "Unknown direction: %s" % (direction,),
+                errcode=Codes.INVALID_PARAM,
             )
 
         event_reports, total = await self.store.get_event_reports_paginate(
@@ -90,7 +93,7 @@ class EventReportsRestServlet(RestServlet):
         if (start + limit) < total:
             ret["next_token"] = start + len(event_reports)
 
-        return 200, ret
+        return HTTPStatus.OK, ret
 
 
 class EventReportDetailRestServlet(RestServlet):
@@ -127,13 +130,17 @@ class EventReportDetailRestServlet(RestServlet):
         try:
             resolved_report_id = int(report_id)
         except ValueError:
-            raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM
+            )
 
         if resolved_report_id < 0:
-            raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM
+            )
 
         ret = await self.store.get_event_report(resolved_report_id)
         if not ret:
             raise NotFoundError("Event report not found")
 
-        return 200, ret
+        return HTTPStatus.OK, ret
diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py
new file mode 100644
index 0000000000..744687be35
--- /dev/null
+++ b/synapse/rest/admin/federation.py
@@ -0,0 +1,135 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from http import HTTPStatus
+from typing import TYPE_CHECKING, Tuple
+
+from synapse.api.errors import Codes, NotFoundError, SynapseError
+from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.http.site import SynapseRequest
+from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
+from synapse.storage.databases.main.transactions import DestinationSortOrder
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class ListDestinationsRestServlet(RestServlet):
+    """Get request to list all destinations.
+    This needs user to have administrator access in Synapse.
+
+    GET /_synapse/admin/v1/federation/destinations?from=0&limit=10
+
+    returns:
+        200 OK with a list of destinations on success, otherwise an error.
+
+    The parameters `from` and `limit` are used only for pagination.
+    By default, a `limit` of 100 is used.
+    The parameter `destination` can be used to filter by destination.
+    The parameter `order_by` can be used to order the result.
+    """
+
+    PATTERNS = admin_patterns("/federation/destinations$")
+
+    def __init__(self, hs: "HomeServer"):
+        self._auth = hs.get_auth()
+        self._store = hs.get_datastore()
+
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        await assert_requester_is_admin(self._auth, request)
+
+        start = parse_integer(request, "from", default=0)
+        limit = parse_integer(request, "limit", default=100)
+
+        if start < 0:
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST,
+                "Query parameter from must be a string representing a positive integer.",
+                errcode=Codes.INVALID_PARAM,
+            )
+
+        if limit < 0:
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST,
+                "Query parameter limit must be a string representing a positive integer.",
+                errcode=Codes.INVALID_PARAM,
+            )
+
+        destination = parse_string(request, "destination")
+
+        order_by = parse_string(
+            request,
+            "order_by",
+            default=DestinationSortOrder.DESTINATION.value,
+            allowed_values=[dest.value for dest in DestinationSortOrder],
+        )
+
+        direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
+
+        destinations, total = await self._store.get_destinations_paginate(
+            start, limit, destination, order_by, direction
+        )
+        response = {"destinations": destinations, "total": total}
+        if (start + limit) < total:
+            response["next_token"] = str(start + len(destinations))
+
+        return HTTPStatus.OK, response
+
+
+class DestinationsRestServlet(RestServlet):
+    """Get details of a destination.
+    This requires the user to have administrator access in Synapse.
+
+    GET /_synapse/admin/v1/federation/destinations/<destination>
+
+    returns:
+        200 OK with details of a destination on success, otherwise an error.
+    """
+
+    PATTERNS = admin_patterns("/federation/destinations/(?P<destination>[^/]+)$")
+
+    def __init__(self, hs: "HomeServer"):
+        self._auth = hs.get_auth()
+        self._store = hs.get_datastore()
+
+    async def on_GET(
+        self, request: SynapseRequest, destination: str
+    ) -> Tuple[int, JsonDict]:
+        await assert_requester_is_admin(self._auth, request)
+
+        destination_retry_timings = await self._store.get_destination_retry_timings(
+            destination
+        )
+
+        if not destination_retry_timings:
+            raise NotFoundError("Unknown destination")
+
+        last_successful_stream_ordering = (
+            await self._store.get_destination_last_successful_stream_ordering(
+                destination
+            )
+        )
+
+        response = {
+            "destination": destination,
+            "failure_ts": destination_retry_timings.failure_ts,
+            "retry_last_ts": destination_retry_timings.retry_last_ts,
+            "retry_interval": destination_retry_timings.retry_interval,
+            "last_successful_stream_ordering": last_successful_stream_ordering,
+        }
+
+        return HTTPStatus.OK, response
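
As a usage illustration of the two new endpoints, assuming an admin access token and the `requests` library (hostname, token, and the queried destination are placeholders):

    import requests

    base = "https://homeserver.example.org/_synapse/admin/v1"
    headers = {"Authorization": "Bearer <admin_access_token>"}

    # List destinations, ten at a time, starting from the beginning.
    listing = requests.get(
        f"{base}/federation/destinations",
        headers=headers,
        params={"from": 0, "limit": 10},
    ).json()  # {"destinations": [...], "total": N} plus "next_token" when more remain

    # Fetch retry timings for a single destination.
    details = requests.get(
        f"{base}/federation/destinations/matrix.org", headers=headers
    ).json()
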
diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py
index 68a3ba3cb7..a27110388f 100644
--- a/synapse/rest/admin/groups.py
+++ b/synapse/rest/admin/groups.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from http import HTTPStatus
 from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.errors import SynapseError
@@ -43,7 +44,7 @@ class DeleteGroupAdminRestServlet(RestServlet):
         await assert_user_is_admin(self.auth, requester.user)
 
         if not self.is_mine_id(group_id):
-            raise SynapseError(400, "Can only delete local groups")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local groups")
 
         await self.group_server.delete_group(group_id, requester.user.to_string())
-        return 200, {}
+        return HTTPStatus.OK, {}
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 30a687d234..9e23e2d8fc 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from http import HTTPStatus
 from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
@@ -62,7 +63,7 @@ class QuarantineMediaInRoom(RestServlet):
             room_id, requester.user.to_string()
         )
 
-        return 200, {"num_quarantined": num_quarantined}
+        return HTTPStatus.OK, {"num_quarantined": num_quarantined}
 
 
 class QuarantineMediaByUser(RestServlet):
@@ -89,7 +90,7 @@ class QuarantineMediaByUser(RestServlet):
             user_id, requester.user.to_string()
         )
 
-        return 200, {"num_quarantined": num_quarantined}
+        return HTTPStatus.OK, {"num_quarantined": num_quarantined}
 
 
 class QuarantineMediaByID(RestServlet):
@@ -118,7 +119,7 @@ class QuarantineMediaByID(RestServlet):
             server_name, media_id, requester.user.to_string()
         )
 
-        return 200, {}
+        return HTTPStatus.OK, {}
 
 
 class UnquarantineMediaByID(RestServlet):
@@ -147,7 +148,7 @@ class UnquarantineMediaByID(RestServlet):
         # Remove from quarantine this media id
         await self.store.quarantine_media_by_id(server_name, media_id, None)
 
-        return 200, {}
+        return HTTPStatus.OK, {}
 
 
 class ProtectMediaByID(RestServlet):
@@ -170,7 +171,7 @@ class ProtectMediaByID(RestServlet):
         # Protect this media id
         await self.store.mark_local_media_as_safe(media_id, safe=True)
 
-        return 200, {}
+        return HTTPStatus.OK, {}
 
 
 class UnprotectMediaByID(RestServlet):
@@ -193,7 +194,7 @@ class UnprotectMediaByID(RestServlet):
         # Unprotect this media id
         await self.store.mark_local_media_as_safe(media_id, safe=False)
 
-        return 200, {}
+        return HTTPStatus.OK, {}
 
 
 class ListMediaInRoom(RestServlet):
@@ -211,11 +212,11 @@ class ListMediaInRoom(RestServlet):
         requester = await self.auth.get_user_by_req(request)
         is_admin = await self.auth.is_server_admin(requester.user)
         if not is_admin:
-            raise AuthError(403, "You are not a server admin")
+            raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")
 
         local_mxcs, remote_mxcs = await self.store.get_media_mxcs_in_room(room_id)
 
-        return 200, {"local": local_mxcs, "remote": remote_mxcs}
+        return HTTPStatus.OK, {"local": local_mxcs, "remote": remote_mxcs}
 
 
 class PurgeMediaCacheRestServlet(RestServlet):
@@ -233,13 +234,13 @@ class PurgeMediaCacheRestServlet(RestServlet):
 
         if before_ts < 0:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Query parameter before_ts must be a positive integer.",
                 errcode=Codes.INVALID_PARAM,
             )
         elif before_ts < 30000000000:  # Dec 1970 in milliseconds, Aug 2920 in seconds
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Query parameter before_ts you provided is from the year 1970. "
                 + "Double check that you are providing a timestamp in milliseconds.",
                 errcode=Codes.INVALID_PARAM,
@@ -247,7 +248,7 @@ class PurgeMediaCacheRestServlet(RestServlet):
 
         ret = await self.media_repository.delete_old_remote_media(before_ts)
 
-        return 200, ret
+        return HTTPStatus.OK, ret
 
 
 class DeleteMediaByID(RestServlet):
@@ -267,7 +268,7 @@ class DeleteMediaByID(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         if self.server_name != server_name:
-            raise SynapseError(400, "Can only delete local media")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media")
 
         if await self.store.get_local_media(media_id) is None:
             raise NotFoundError("Unknown media")
@@ -277,7 +278,7 @@ class DeleteMediaByID(RestServlet):
         deleted_media, total = await self.media_repository.delete_local_media_ids(
             [media_id]
         )
-        return 200, {"deleted_media": deleted_media, "total": total}
+        return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total}
 
 
 class DeleteMediaByDateSize(RestServlet):
@@ -304,26 +305,26 @@ class DeleteMediaByDateSize(RestServlet):
 
         if before_ts < 0:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Query parameter before_ts must be a positive integer.",
                 errcode=Codes.INVALID_PARAM,
             )
         elif before_ts < 30000000000:  # Dec 1970 in milliseconds, Aug 2920 in seconds
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Query parameter before_ts you provided is from the year 1970. "
                 + "Double check that you are providing a timestamp in milliseconds.",
                 errcode=Codes.INVALID_PARAM,
             )
         if size_gt < 0:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Query parameter size_gt must be a string representing a positive integer.",
                 errcode=Codes.INVALID_PARAM,
             )
 
         if self.server_name != server_name:
-            raise SynapseError(400, "Can only delete local media")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media")
 
         logging.info(
             "Deleting local media by timestamp: %s, size larger than: %s, keep profile media: %s"
@@ -333,7 +334,7 @@ class DeleteMediaByDateSize(RestServlet):
         deleted_media, total = await self.media_repository.delete_old_local_media(
             before_ts, size_gt, keep_profiles
         )
-        return 200, {"deleted_media": deleted_media, "total": total}
+        return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total}
 
 
 class UserMediaRestServlet(RestServlet):
@@ -369,7 +370,7 @@ class UserMediaRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         if not self.is_mine(UserID.from_string(user_id)):
-            raise SynapseError(400, "Can only look up local users")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
 
         user = await self.store.get_user_by_id(user_id)
         if user is None:
@@ -380,14 +381,14 @@ class UserMediaRestServlet(RestServlet):
 
         if start < 0:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Query parameter from must be a string representing a positive integer.",
                 errcode=Codes.INVALID_PARAM,
             )
 
         if limit < 0:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Query parameter limit must be a string representing a positive integer.",
                 errcode=Codes.INVALID_PARAM,
             )
@@ -425,7 +426,7 @@ class UserMediaRestServlet(RestServlet):
         if (start + limit) < total:
             ret["next_token"] = start + len(media)
 
-        return 200, ret
+        return HTTPStatus.OK, ret
 
     async def on_DELETE(
         self, request: SynapseRequest, user_id: str
@@ -436,7 +437,7 @@ class UserMediaRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         if not self.is_mine(UserID.from_string(user_id)):
-            raise SynapseError(400, "Can only look up local users")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
 
         user = await self.store.get_user_by_id(user_id)
         if user is None:
@@ -447,14 +448,14 @@ class UserMediaRestServlet(RestServlet):
 
         if start < 0:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Query parameter from must be a string representing a positive integer.",
                 errcode=Codes.INVALID_PARAM,
             )
 
         if limit < 0:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Query parameter limit must be a string representing a positive integer.",
                 errcode=Codes.INVALID_PARAM,
             )
@@ -492,7 +493,7 @@ class UserMediaRestServlet(RestServlet):
             ([row["media_id"] for row in media])
         )
 
-        return 200, {"deleted_media": deleted_media, "total": total}
+        return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total}
 
 
 def register_servlets_for_media_repo(hs: "HomeServer", http_server: HttpServer) -> None:
diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py
index aba48f6e7b..891b98c088 100644
--- a/synapse/rest/admin/registration_tokens.py
+++ b/synapse/rest/admin/registration_tokens.py
@@ -14,6 +14,7 @@
 
 import logging
 import string
+from http import HTTPStatus
 from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.errors import Codes, NotFoundError, SynapseError
@@ -77,7 +78,7 @@ class ListRegistrationTokensRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
         valid = parse_boolean(request, "valid")
         token_list = await self.store.get_registration_tokens(valid)
-        return 200, {"registration_tokens": token_list}
+        return HTTPStatus.OK, {"registration_tokens": token_list}
 
 
 class NewRegistrationTokenRestServlet(RestServlet):
@@ -123,16 +124,20 @@ class NewRegistrationTokenRestServlet(RestServlet):
         if "token" in body:
             token = body["token"]
             if not isinstance(token, str):
-                raise SynapseError(400, "token must be a string", Codes.INVALID_PARAM)
+                raise SynapseError(
+                    HTTPStatus.BAD_REQUEST,
+                    "token must be a string",
+                    Codes.INVALID_PARAM,
+                )
             if not (0 < len(token) <= 64):
                 raise SynapseError(
-                    400,
+                    HTTPStatus.BAD_REQUEST,
                     "token must not be empty and must not be longer than 64 characters",
                     Codes.INVALID_PARAM,
                 )
             if not set(token).issubset(self.allowed_chars_set):
                 raise SynapseError(
-                    400,
+                    HTTPStatus.BAD_REQUEST,
                     "token must consist only of characters matched by the regex [A-Za-z0-9-_]",
                     Codes.INVALID_PARAM,
                 )
@@ -142,11 +147,13 @@ class NewRegistrationTokenRestServlet(RestServlet):
             length = body.get("length", 16)
             if not isinstance(length, int):
                 raise SynapseError(
-                    400, "length must be an integer", Codes.INVALID_PARAM
+                    HTTPStatus.BAD_REQUEST,
+                    "length must be an integer",
+                    Codes.INVALID_PARAM,
                 )
             if not (0 < length <= 64):
                 raise SynapseError(
-                    400,
+                    HTTPStatus.BAD_REQUEST,
                     "length must be greater than zero and not greater than 64",
                     Codes.INVALID_PARAM,
                 )
@@ -162,7 +169,7 @@ class NewRegistrationTokenRestServlet(RestServlet):
             or (isinstance(uses_allowed, int) and uses_allowed >= 0)
         ):
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "uses_allowed must be a non-negative integer or null",
                 Codes.INVALID_PARAM,
             )
@@ -170,11 +177,15 @@ class NewRegistrationTokenRestServlet(RestServlet):
         expiry_time = body.get("expiry_time", None)
         if not isinstance(expiry_time, (int, type(None))):
             raise SynapseError(
-                400, "expiry_time must be an integer or null", Codes.INVALID_PARAM
+                HTTPStatus.BAD_REQUEST,
+                "expiry_time must be an integer or null",
+                Codes.INVALID_PARAM,
             )
         if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
             raise SynapseError(
-                400, "expiry_time must not be in the past", Codes.INVALID_PARAM
+                HTTPStatus.BAD_REQUEST,
+                "expiry_time must not be in the past",
+                Codes.INVALID_PARAM,
             )
 
         created = await self.store.create_registration_token(
@@ -182,7 +193,9 @@ class NewRegistrationTokenRestServlet(RestServlet):
         )
         if not created:
             raise SynapseError(
-                400, f"Token already exists: {token}", Codes.INVALID_PARAM
+                HTTPStatus.BAD_REQUEST,
+                f"Token already exists: {token}",
+                Codes.INVALID_PARAM,
             )
 
         resp = {
@@ -192,7 +205,7 @@ class NewRegistrationTokenRestServlet(RestServlet):
             "completed": 0,
             "expiry_time": expiry_time,
         }
-        return 200, resp
+        return HTTPStatus.OK, resp
 
 
 class RegistrationTokenRestServlet(RestServlet):
@@ -261,7 +274,7 @@ class RegistrationTokenRestServlet(RestServlet):
         if token_info is None:
             raise NotFoundError(f"No such registration token: {token}")
 
-        return 200, token_info
+        return HTTPStatus.OK, token_info
 
     async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]:
         """Update a registration token."""
@@ -277,7 +290,7 @@ class RegistrationTokenRestServlet(RestServlet):
                 or (isinstance(uses_allowed, int) and uses_allowed >= 0)
             ):
                 raise SynapseError(
-                    400,
+                    HTTPStatus.BAD_REQUEST,
                     "uses_allowed must be a non-negative integer or null",
                     Codes.INVALID_PARAM,
                 )
@@ -287,11 +300,15 @@ class RegistrationTokenRestServlet(RestServlet):
             expiry_time = body["expiry_time"]
             if not isinstance(expiry_time, (int, type(None))):
                 raise SynapseError(
-                    400, "expiry_time must be an integer or null", Codes.INVALID_PARAM
+                    HTTPStatus.BAD_REQUEST,
+                    "expiry_time must be an integer or null",
+                    Codes.INVALID_PARAM,
                 )
             if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
                 raise SynapseError(
-                    400, "expiry_time must not be in the past", Codes.INVALID_PARAM
+                    HTTPStatus.BAD_REQUEST,
+                    "expiry_time must not be in the past",
+                    Codes.INVALID_PARAM,
                 )
             new_attributes["expiry_time"] = expiry_time
 
@@ -307,7 +324,7 @@ class RegistrationTokenRestServlet(RestServlet):
         if token_info is None:
             raise NotFoundError(f"No such registration token: {token}")
 
-        return 200, token_info
+        return HTTPStatus.OK, token_info
 
     async def on_DELETE(
         self, request: SynapseRequest, token: str
@@ -316,6 +333,6 @@ class RegistrationTokenRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         if await self.store.delete_registration_token(token):
-            return 200, {}
+            return HTTPStatus.OK, {}
 
         raise NotFoundError(f"No such registration token: {token}")
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 37cb4d0796..669ab44a45 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -102,7 +102,9 @@ class RoomRestV2Servlet(RestServlet):
             )
 
         if not RoomID.is_valid(room_id):
-            raise SynapseError(400, "%s is not a legal room ID" % (room_id,))
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,)
+            )
 
         if not await self._store.get_room(room_id):
             raise NotFoundError("Unknown room id %s" % (room_id,))
@@ -118,7 +120,7 @@ class RoomRestV2Servlet(RestServlet):
             force_purge=force_purge,
         )
 
-        return 200, {"delete_id": delete_id}
+        return HTTPStatus.OK, {"delete_id": delete_id}
 
 
 class DeleteRoomStatusByRoomIdRestServlet(RestServlet):
@@ -137,7 +139,9 @@ class DeleteRoomStatusByRoomIdRestServlet(RestServlet):
         await assert_requester_is_admin(self._auth, request)
 
         if not RoomID.is_valid(room_id):
-            raise SynapseError(400, "%s is not a legal room ID" % (room_id,))
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,)
+            )
 
         delete_ids = self._pagination_handler.get_delete_ids_by_room(room_id)
         if delete_ids is None:
@@ -153,7 +157,7 @@ class DeleteRoomStatusByRoomIdRestServlet(RestServlet):
                         **delete.asdict(),
                     }
                 ]
-        return 200, {"results": cast(JsonDict, response)}
+        return HTTPStatus.OK, {"results": cast(JsonDict, response)}
 
 
 class DeleteRoomStatusByDeleteIdRestServlet(RestServlet):
@@ -175,7 +179,7 @@ class DeleteRoomStatusByDeleteIdRestServlet(RestServlet):
         if delete_status is None:
             raise NotFoundError("delete id '%s' not found" % delete_id)
 
-        return 200, cast(JsonDict, delete_status.asdict())
+        return HTTPStatus.OK, cast(JsonDict, delete_status.asdict())
 
 
 class ListRoomRestServlet(RestServlet):
@@ -217,7 +221,7 @@ class ListRoomRestServlet(RestServlet):
             RoomSortOrder.STATE_EVENTS.value,
         ):
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Unknown value for order_by: %s" % (order_by,),
                 errcode=Codes.INVALID_PARAM,
             )
@@ -225,7 +229,7 @@ class ListRoomRestServlet(RestServlet):
         search_term = parse_string(request, "search_term", encoding="utf-8")
         if search_term == "":
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "search_term cannot be an empty string",
                 errcode=Codes.INVALID_PARAM,
             )
@@ -233,7 +237,9 @@ class ListRoomRestServlet(RestServlet):
         direction = parse_string(request, "dir", default="f")
         if direction not in ("f", "b"):
             raise SynapseError(
-                400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM
+                HTTPStatus.BAD_REQUEST,
+                "Unknown direction: %s" % (direction,),
+                errcode=Codes.INVALID_PARAM,
             )
 
         reverse_order = True if direction == "b" else False
@@ -265,7 +271,7 @@ class ListRoomRestServlet(RestServlet):
             else:
                 response["prev_batch"] = 0
 
-        return 200, response
+        return HTTPStatus.OK, response
 
 
 class RoomRestServlet(RestServlet):
@@ -310,7 +316,7 @@ class RoomRestServlet(RestServlet):
         members = await self.store.get_users_in_room(room_id)
         ret["joined_local_devices"] = await self.store.count_devices_by_users(members)
 
-        return 200, ret
+        return HTTPStatus.OK, ret
 
     async def on_DELETE(
         self, request: SynapseRequest, room_id: str
@@ -386,7 +392,7 @@ class RoomRestServlet(RestServlet):
         # See https://github.com/python/mypy/issues/4976#issuecomment-579883622
         # for some discussion on why this is necessary. Either way,
         # `ret` is an opaque dictionary blob as far as the rest of the app cares.
-        return 200, cast(JsonDict, ret)
+        return HTTPStatus.OK, cast(JsonDict, ret)
 
 
 class RoomMembersRestServlet(RestServlet):
@@ -413,7 +419,7 @@ class RoomMembersRestServlet(RestServlet):
         members = await self.store.get_users_in_room(room_id)
         ret = {"members": members, "total": len(members)}
 
-        return 200, ret
+        return HTTPStatus.OK, ret
 
 
 class RoomStateRestServlet(RestServlet):
@@ -443,16 +449,10 @@ class RoomStateRestServlet(RestServlet):
         event_ids = await self.store.get_current_state_ids(room_id)
         events = await self.store.get_events(event_ids.values())
         now = self.clock.time_msec()
-        room_state = await self._event_serializer.serialize_events(
-            events.values(),
-            now,
-            # We don't bother bundling aggregations in when asked for state
-            # events, as clients won't use them.
-            bundle_aggregations=False,
-        )
+        room_state = await self._event_serializer.serialize_events(events.values(), now)
         ret = {"state": room_state}
 
-        return 200, ret
+        return HTTPStatus.OK, ret
 
 
 class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
@@ -481,7 +481,10 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
         target_user = UserID.from_string(content["user_id"])
 
         if not self.hs.is_mine(target_user):
-            raise SynapseError(400, "This endpoint can only be used with local users")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST,
+                "This endpoint can only be used with local users",
+            )
 
         if not await self.admin_handler.get_user(target_user):
             raise NotFoundError("User not found")
@@ -527,7 +530,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
             ratelimit=False,
         )
 
-        return 200, {"room_id": room_id}
+        return HTTPStatus.OK, {"room_id": room_id}
 
 
 class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
@@ -568,7 +571,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
         # Figure out which local users currently have power in the room, if any.
         room_state = await self.state_handler.get_current_state(room_id)
         if not room_state:
-            raise SynapseError(400, "Server not in room")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Server not in room")
 
         create_event = room_state[(EventTypes.Create, "")]
         power_levels = room_state.get((EventTypes.PowerLevels, ""))
@@ -582,7 +585,9 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
             admin_users.sort(key=lambda user: user_power[user])
 
             if not admin_users:
-                raise SynapseError(400, "No local admin user in room")
+                raise SynapseError(
+                    HTTPStatus.BAD_REQUEST, "No local admin user in room"
+                )
 
             admin_user_id = None
 
@@ -599,7 +604,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
 
             if not admin_user_id:
                 raise SynapseError(
-                    400,
+                    HTTPStatus.BAD_REQUEST,
                     "No local admin user in room",
                 )
 
@@ -610,7 +615,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
             admin_user_id = create_event.sender
             if not self.is_mine_id(admin_user_id):
                 raise SynapseError(
-                    400,
+                    HTTPStatus.BAD_REQUEST,
                     "No local admin user in room",
                 )
 
@@ -639,7 +644,8 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
         except AuthError:
             # The admin user we found turned out not to have enough power.
             raise SynapseError(
-                400, "No local admin user in room with power to update power levels."
+                HTTPStatus.BAD_REQUEST,
+                "No local admin user in room with power to update power levels.",
             )
 
         # Now we check if the user we're granting admin rights to is already in
@@ -653,7 +659,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
             )
 
         if is_joined:
-            return 200, {}
+            return HTTPStatus.OK, {}
 
         join_rules = room_state.get((EventTypes.JoinRules, ""))
         is_public = False
@@ -661,7 +667,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
             is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC
 
         if is_public:
-            return 200, {}
+            return HTTPStatus.OK, {}
 
         await self.room_member_handler.update_membership(
             fake_requester,
@@ -670,7 +676,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
             action=Membership.INVITE,
         )
 
-        return 200, {}
+        return HTTPStatus.OK, {}
 
 
 class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
@@ -702,7 +708,7 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
         room_id, _ = await self.resolve_room_id(room_identifier)
 
         deleted_count = await self.store.delete_forward_extremities_for_room(room_id)
-        return 200, {"deleted": deleted_count}
+        return HTTPStatus.OK, {"deleted": deleted_count}
 
     async def on_GET(
         self, request: SynapseRequest, room_identifier: str
@@ -713,7 +719,7 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
         room_id, _ = await self.resolve_room_id(room_identifier)
 
         extremities = await self.store.get_forward_extremities_for_room(room_id)
-        return 200, {"count": len(extremities), "results": extremities}
+        return HTTPStatus.OK, {"count": len(extremities), "results": extremities}
 
 
 class RoomEventContextServlet(RestServlet):
@@ -762,7 +768,9 @@ class RoomEventContextServlet(RestServlet):
         )
 
         if not results:
-            raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
+            raise SynapseError(
+                HTTPStatus.NOT_FOUND, "Event not found.", errcode=Codes.NOT_FOUND
+            )
 
         time_now = self.clock.time_msec()
         results["events_before"] = await self._event_serializer.serialize_events(
@@ -775,10 +783,70 @@ class RoomEventContextServlet(RestServlet):
             results["events_after"], time_now
         )
         results["state"] = await self._event_serializer.serialize_events(
-            results["state"],
-            time_now,
-            # No need to bundle aggregations for state events
-            bundle_aggregations=False,
+            results["state"], time_now
         )
 
-        return 200, results
+        return HTTPStatus.OK, results
+
+
+class BlockRoomRestServlet(RestServlet):
+    """
+    Manage blocking of rooms.
+    On PUT: Add a room to, or remove it from, the blocking list.
+    On GET: Get the blocking status of a room and the user who blocked it.
+    """
+
+    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/block$")
+
+    def __init__(self, hs: "HomeServer"):
+        self._auth = hs.get_auth()
+        self._store = hs.get_datastore()
+
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
+        await assert_requester_is_admin(self._auth, request)
+
+        if not RoomID.is_valid(room_id):
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,)
+            )
+
+        blocked_by = await self._store.room_is_blocked_by(room_id)
+        # Check for `not None` rather than truthiness, since `user_id` could be
+        # an empty string if someone added an entry to the database manually.
+        if blocked_by is not None:
+            response = {"block": True, "user_id": blocked_by}
+        else:
+            response = {"block": False}
+
+        return HTTPStatus.OK, response
+
+    async def on_PUT(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
+        requester = await self._auth.get_user_by_req(request)
+        await assert_user_is_admin(self._auth, requester.user)
+
+        content = parse_json_object_from_request(request)
+
+        if not RoomID.is_valid(room_id):
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,)
+            )
+
+        assert_params_in_dict(content, ["block"])
+        block = content.get("block")
+        if not isinstance(block, bool):
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST,
+                "Param 'block' must be a boolean.",
+                Codes.BAD_JSON,
+            )
+
+        if block:
+            await self._store.block_room(room_id, requester.user.to_string())
+        else:
+            await self._store.unblock_room(room_id)
+
+        return HTTPStatus.OK, {"block": block}
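
For orientation, the BlockRoomRestServlet added above is served under the admin API prefix (normally /_synapse/admin/v1). The sketch below is illustrative only, assuming a reachable homeserver and an admin access token; the URL, token and room ID are placeholders, and the third-party requests library stands in for a real admin client.

    import requests  # third-party HTTP client, used only for illustration
    from urllib.parse import quote

    HOMESERVER = "https://localhost:8008"  # placeholder
    ADMIN_TOKEN = "<admin access token>"   # placeholder
    ROOM_ID = "!example:example.com"       # placeholder
    HEADERS = {"Authorization": f"Bearer {ADMIN_TOKEN}"}

    # PUT expects a JSON body with a boolean "block" field and echoes it back.
    resp = requests.put(
        f"{HOMESERVER}/_synapse/admin/v1/rooms/{quote(ROOM_ID, safe='')}/block",
        json={"block": True},
        headers=HEADERS,
    )
    print(resp.status_code, resp.json())  # e.g. 200 {'block': True}

    # GET reports whether the room is blocked and, if so, who blocked it.
    resp = requests.get(
        f"{HOMESERVER}/_synapse/admin/v1/rooms/{quote(ROOM_ID, safe='')}/block",
        headers=HEADERS,
    )
    print(resp.json())  # e.g. {'block': True, 'user_id': '@admin:example.com'}
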
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index 19f84f33f2..b295fb078b 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from http import HTTPStatus
 from typing import TYPE_CHECKING, Awaitable, Optional, Tuple
 
 from synapse.api.constants import EventTypes
@@ -82,11 +83,15 @@ class SendServerNoticeServlet(RestServlet):
         # but worker processes still need to initialise SendServerNoticeServlet (as it is part of the
         # admin api).
         if not self.server_notices_manager.is_enabled():
-            raise SynapseError(400, "Server notices are not enabled on this server")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "Server notices are not enabled on this server"
+            )
 
         target_user = UserID.from_string(body["user_id"])
         if not self.hs.is_mine(target_user):
-            raise SynapseError(400, "Server notices can only be sent to local users")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "Server notices can only be sent to local users"
+            )
 
         if not await self.admin_handler.get_user(target_user):
             raise NotFoundError("User not found")
@@ -99,7 +104,7 @@ class SendServerNoticeServlet(RestServlet):
             txn_id=txn_id,
         )
 
-        return 200, {"event_id": event.event_id}
+        return HTTPStatus.OK, {"event_id": event.event_id}
 
     def on_PUT(
         self, request: SynapseRequest, txn_id: str
diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py
index 948de94ccd..ca41fd45f2 100644
--- a/synapse/rest/admin/statistics.py
+++ b/synapse/rest/admin/statistics.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+from http import HTTPStatus
 from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.errors import Codes, SynapseError
@@ -53,7 +54,7 @@ class UserMediaStatisticsRestServlet(RestServlet):
             UserSortOrder.DISPLAYNAME.value,
         ):
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Unknown value for order_by: %s" % (order_by,),
                 errcode=Codes.INVALID_PARAM,
             )
@@ -61,7 +62,7 @@ class UserMediaStatisticsRestServlet(RestServlet):
         start = parse_integer(request, "from", default=0)
         if start < 0:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Query parameter from must be a string representing a positive integer.",
                 errcode=Codes.INVALID_PARAM,
             )
@@ -69,7 +70,7 @@ class UserMediaStatisticsRestServlet(RestServlet):
         limit = parse_integer(request, "limit", default=100)
         if limit < 0:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Query parameter limit must be a string representing a positive integer.",
                 errcode=Codes.INVALID_PARAM,
             )
@@ -77,7 +78,7 @@ class UserMediaStatisticsRestServlet(RestServlet):
         from_ts = parse_integer(request, "from_ts", default=0)
         if from_ts < 0:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Query parameter from_ts must be a string representing a positive integer.",
                 errcode=Codes.INVALID_PARAM,
             )
@@ -86,13 +87,13 @@ class UserMediaStatisticsRestServlet(RestServlet):
         if until_ts is not None:
             if until_ts < 0:
                 raise SynapseError(
-                    400,
+                    HTTPStatus.BAD_REQUEST,
                     "Query parameter until_ts must be a string representing a positive integer.",
                     errcode=Codes.INVALID_PARAM,
                 )
             if until_ts <= from_ts:
                 raise SynapseError(
-                    400,
+                    HTTPStatus.BAD_REQUEST,
                     "Query parameter until_ts must be greater than from_ts.",
                     errcode=Codes.INVALID_PARAM,
                 )
@@ -100,7 +101,7 @@ class UserMediaStatisticsRestServlet(RestServlet):
         search_term = parse_string(request, "search_term")
         if search_term == "":
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Query parameter search_term cannot be an empty string.",
                 errcode=Codes.INVALID_PARAM,
             )
@@ -108,7 +109,9 @@ class UserMediaStatisticsRestServlet(RestServlet):
         direction = parse_string(request, "dir", default="f")
         if direction not in ("f", "b"):
             raise SynapseError(
-                400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM
+                HTTPStatus.BAD_REQUEST,
+                "Unknown direction: %s" % (direction,),
+                errcode=Codes.INVALID_PARAM,
             )
 
         users_media, total = await self.store.get_users_media_usage_paginate(
@@ -118,4 +121,4 @@ class UserMediaStatisticsRestServlet(RestServlet):
         if (start + limit) < total:
             ret["next_token"] = start + len(users_media)
 
-        return 200, ret
+        return HTTPStatus.OK, ret
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 23a8bf1fdb..2a60b602b1 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -79,14 +79,14 @@ class UsersRestServletV2(RestServlet):
 
         if start < 0:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Query parameter from must be a string representing a positive integer.",
                 errcode=Codes.INVALID_PARAM,
             )
 
         if limit < 0:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "Query parameter limit must be a string representing a positive integer.",
                 errcode=Codes.INVALID_PARAM,
             )
@@ -122,7 +122,7 @@ class UsersRestServletV2(RestServlet):
         if (start + limit) < total:
             ret["next_token"] = str(start + len(users))
 
-        return 200, ret
+        return HTTPStatus.OK, ret
 
 
 class UserRestServletV2(RestServlet):
@@ -172,14 +172,14 @@ class UserRestServletV2(RestServlet):
 
         target_user = UserID.from_string(user_id)
         if not self.hs.is_mine(target_user):
-            raise SynapseError(400, "Can only look up local users")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
 
         ret = await self.admin_handler.get_user(target_user)
 
         if not ret:
             raise NotFoundError("User not found")
 
-        return 200, ret
+        return HTTPStatus.OK, ret
 
     async def on_PUT(
         self, request: SynapseRequest, user_id: str
@@ -191,7 +191,10 @@ class UserRestServletV2(RestServlet):
         body = parse_json_object_from_request(request)
 
         if not self.hs.is_mine(target_user):
-            raise SynapseError(400, "This endpoint can only be used with local users")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST,
+                "This endpoint can only be used with local users",
+            )
 
         user = await self.admin_handler.get_user(target_user)
         user_id = target_user.to_string()
@@ -210,7 +213,7 @@ class UserRestServletV2(RestServlet):
 
         user_type = body.get("user_type", None)
         if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
-            raise SynapseError(400, "Invalid user type")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid user type")
 
         set_admin_to = body.get("admin", False)
         if not isinstance(set_admin_to, bool):
@@ -223,11 +226,13 @@ class UserRestServletV2(RestServlet):
         password = body.get("password", None)
         if password is not None:
             if not isinstance(password, str) or len(password) > 512:
-                raise SynapseError(400, "Invalid password")
+                raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid password")
 
         deactivate = body.get("deactivated", False)
         if not isinstance(deactivate, bool):
-            raise SynapseError(400, "'deactivated' parameter is not of type boolean")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "'deactivated' parameter is not of type boolean"
+            )
 
         # convert List[Dict[str, str]] into List[Tuple[str, str]]
         if external_ids is not None:
@@ -282,7 +287,9 @@ class UserRestServletV2(RestServlet):
                         user_id,
                     )
                 except ExternalIDReuseException:
-                    raise SynapseError(409, "External id is already in use.")
+                    raise SynapseError(
+                        HTTPStatus.CONFLICT, "External id is already in use."
+                    )
 
             if "avatar_url" in body and isinstance(body["avatar_url"], str):
                 await self.profile_handler.set_avatar_url(
@@ -293,7 +300,9 @@ class UserRestServletV2(RestServlet):
                 if set_admin_to != user["admin"]:
                     auth_user = requester.user
                     if target_user == auth_user and not set_admin_to:
-                        raise SynapseError(400, "You may not demote yourself.")
+                        raise SynapseError(
+                            HTTPStatus.BAD_REQUEST, "You may not demote yourself."
+                        )
 
                     await self.store.set_server_admin(target_user, set_admin_to)
 
@@ -319,7 +328,8 @@ class UserRestServletV2(RestServlet):
                         and self.auth_handler.can_change_password()
                     ):
                         raise SynapseError(
-                            400, "Must provide a password to re-activate an account."
+                            HTTPStatus.BAD_REQUEST,
+                            "Must provide a password to re-activate an account.",
                         )
 
                     await self.deactivate_account_handler.activate_account(
@@ -332,7 +342,7 @@ class UserRestServletV2(RestServlet):
             user = await self.admin_handler.get_user(target_user)
             assert user is not None
 
-            return 200, user
+            return HTTPStatus.OK, user
 
         else:  # create user
             displayname = body.get("displayname", None)
@@ -381,7 +391,9 @@ class UserRestServletV2(RestServlet):
                             user_id,
                         )
                 except ExternalIDReuseException:
-                    raise SynapseError(409, "External id is already in use.")
+                    raise SynapseError(
+                        HTTPStatus.CONFLICT, "External id is already in use."
+                    )
 
             if "avatar_url" in body and isinstance(body["avatar_url"], str):
                 await self.profile_handler.set_avatar_url(
@@ -429,51 +441,61 @@ class UserRegisterServlet(RestServlet):
 
         nonce = secrets.token_hex(64)
         self.nonces[nonce] = int(self.reactor.seconds())
-        return 200, {"nonce": nonce}
+        return HTTPStatus.OK, {"nonce": nonce}
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         self._clear_old_nonces()
 
         if not self.hs.config.registration.registration_shared_secret:
-            raise SynapseError(400, "Shared secret registration is not enabled")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "Shared secret registration is not enabled"
+            )
 
         body = parse_json_object_from_request(request)
 
         if "nonce" not in body:
-            raise SynapseError(400, "nonce must be specified", errcode=Codes.BAD_JSON)
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST,
+                "nonce must be specified",
+                errcode=Codes.BAD_JSON,
+            )
 
         nonce = body["nonce"]
 
         if nonce not in self.nonces:
-            raise SynapseError(400, "unrecognised nonce")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "unrecognised nonce")
 
         # Delete the nonce, so it can't be reused, even if it's invalid
         del self.nonces[nonce]
 
         if "username" not in body:
             raise SynapseError(
-                400, "username must be specified", errcode=Codes.BAD_JSON
+                HTTPStatus.BAD_REQUEST,
+                "username must be specified",
+                errcode=Codes.BAD_JSON,
             )
         else:
             if not isinstance(body["username"], str) or len(body["username"]) > 512:
-                raise SynapseError(400, "Invalid username")
+                raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid username")
 
             username = body["username"].encode("utf-8")
             if b"\x00" in username:
-                raise SynapseError(400, "Invalid username")
+                raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid username")
 
         if "password" not in body:
             raise SynapseError(
-                400, "password must be specified", errcode=Codes.BAD_JSON
+                HTTPStatus.BAD_REQUEST,
+                "password must be specified",
+                errcode=Codes.BAD_JSON,
             )
         else:
             password = body["password"]
             if not isinstance(password, str) or len(password) > 512:
-                raise SynapseError(400, "Invalid password")
+                raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid password")
 
             password_bytes = password.encode("utf-8")
             if b"\x00" in password_bytes:
-                raise SynapseError(400, "Invalid password")
+                raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid password")
 
             password_hash = await self.auth_handler.hash(password)
 
@@ -482,10 +504,12 @@ class UserRegisterServlet(RestServlet):
         displayname = body.get("displayname", None)
 
         if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
-            raise SynapseError(400, "Invalid user type")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid user type")
 
         if "mac" not in body:
-            raise SynapseError(400, "mac must be specified", errcode=Codes.BAD_JSON)
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "mac must be specified", errcode=Codes.BAD_JSON
+            )
 
         got_mac = body["mac"]
 
@@ -507,7 +531,7 @@ class UserRegisterServlet(RestServlet):
         want_mac = want_mac_builder.hexdigest()
 
         if not hmac.compare_digest(want_mac.encode("ascii"), got_mac.encode("ascii")):
-            raise SynapseError(403, "HMAC incorrect")
+            raise SynapseError(HTTPStatus.FORBIDDEN, "HMAC incorrect")
 
         # Reuse the parts of RegisterRestServlet to reduce code duplication
         from synapse.rest.client.register import RegisterRestServlet
@@ -524,7 +548,7 @@ class UserRegisterServlet(RestServlet):
         )
 
         result = await register._create_registration_details(user_id, body)
-        return 200, result
+        return HTTPStatus.OK, result
 
 
 class WhoisRestServlet(RestServlet):
@@ -552,11 +576,11 @@ class WhoisRestServlet(RestServlet):
             await assert_user_is_admin(self.auth, auth_user)
 
         if not self.hs.is_mine(target_user):
-            raise SynapseError(400, "Can only whois a local user")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only whois a local user")
 
         ret = await self.admin_handler.get_whois(target_user)
 
-        return 200, ret
+        return HTTPStatus.OK, ret
 
 
 class DeactivateAccountRestServlet(RestServlet):
@@ -575,7 +599,9 @@ class DeactivateAccountRestServlet(RestServlet):
         await assert_user_is_admin(self.auth, requester.user)
 
         if not self.is_mine(UserID.from_string(target_user_id)):
-            raise SynapseError(400, "Can only deactivate local users")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "Can only deactivate local users"
+            )
 
         if not await self.store.get_user_by_id(target_user_id):
             raise NotFoundError("User not found")
@@ -597,7 +623,7 @@ class DeactivateAccountRestServlet(RestServlet):
         else:
             id_server_unbind_result = "no-support"
 
-        return 200, {"id_server_unbind_result": id_server_unbind_result}
+        return HTTPStatus.OK, {"id_server_unbind_result": id_server_unbind_result}
 
 
 class AccountValidityRenewServlet(RestServlet):
@@ -620,7 +646,7 @@ class AccountValidityRenewServlet(RestServlet):
 
             if "user_id" not in body:
                 raise SynapseError(
-                    400,
+                    HTTPStatus.BAD_REQUEST,
                     "Missing property 'user_id' in the request body",
                 )
 
@@ -631,7 +657,7 @@ class AccountValidityRenewServlet(RestServlet):
             )
 
         res = {"expiration_ts": expiration_ts}
-        return 200, res
+        return HTTPStatus.OK, res
 
 
 class ResetPasswordRestServlet(RestServlet):
@@ -678,7 +704,7 @@ class ResetPasswordRestServlet(RestServlet):
         await self._set_password_handler.set_password(
             target_user_id, new_password_hash, logout_devices, requester
         )
-        return 200, {}
+        return HTTPStatus.OK, {}
 
 
 class SearchUsersRestServlet(RestServlet):
@@ -712,16 +738,16 @@ class SearchUsersRestServlet(RestServlet):
 
         # To allow all users to get the users list
         # if not is_admin and target_user != auth_user:
-        #     raise AuthError(403, "You are not a server admin")
+        #     raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")
 
         if not self.hs.is_mine(target_user):
-            raise SynapseError(400, "Can only users a local user")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only search local users")
 
         term = parse_string(request, "term", required=True)
         logger.info("term: %s ", term)
 
         ret = await self.store.search_users(term)
-        return 200, ret
+        return HTTPStatus.OK, ret
 
 
 class UserAdminServlet(RestServlet):
@@ -765,11 +791,14 @@ class UserAdminServlet(RestServlet):
         target_user = UserID.from_string(user_id)
 
         if not self.hs.is_mine(target_user):
-            raise SynapseError(400, "Only local users can be admins of this homeserver")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST,
+                "Only local users can be admins of this homeserver",
+            )
 
         is_admin = await self.store.is_server_admin(target_user)
 
-        return 200, {"admin": is_admin}
+        return HTTPStatus.OK, {"admin": is_admin}
 
     async def on_PUT(
         self, request: SynapseRequest, user_id: str
@@ -785,16 +814,19 @@ class UserAdminServlet(RestServlet):
         assert_params_in_dict(body, ["admin"])
 
         if not self.hs.is_mine(target_user):
-            raise SynapseError(400, "Only local users can be admins of this homeserver")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST,
+                "Only local users can be admins of this homeserver",
+            )
 
         set_admin_to = bool(body["admin"])
 
         if target_user == auth_user and not set_admin_to:
-            raise SynapseError(400, "You may not demote yourself.")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "You may not demote yourself.")
 
         await self.store.set_server_admin(target_user, set_admin_to)
 
-        return 200, {}
+        return HTTPStatus.OK, {}
 
 
 class UserMembershipRestServlet(RestServlet):
@@ -816,7 +848,7 @@ class UserMembershipRestServlet(RestServlet):
 
         room_ids = await self.store.get_rooms_for_user(user_id)
         ret = {"joined_rooms": list(room_ids), "total": len(room_ids)}
-        return 200, ret
+        return HTTPStatus.OK, ret
 
 
 class PushersRestServlet(RestServlet):
@@ -845,7 +877,7 @@ class PushersRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         if not self.is_mine(UserID.from_string(user_id)):
-            raise SynapseError(400, "Can only look up local users")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
 
         if not await self.store.get_user_by_id(user_id):
             raise NotFoundError("User not found")
@@ -854,7 +886,10 @@ class PushersRestServlet(RestServlet):
 
         filtered_pushers = [p.as_dict() for p in pushers]
 
-        return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)}
+        return HTTPStatus.OK, {
+            "pushers": filtered_pushers,
+            "total": len(filtered_pushers),
+        }
 
 
 class UserTokenRestServlet(RestServlet):
@@ -887,25 +922,31 @@ class UserTokenRestServlet(RestServlet):
         auth_user = requester.user
 
         if not self.hs.is_mine_id(user_id):
-            raise SynapseError(400, "Only local users can be logged in as")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "Only local users can be logged in as"
+            )
 
         body = parse_json_object_from_request(request, allow_empty_body=True)
 
         valid_until_ms = body.get("valid_until_ms")
         if valid_until_ms and not isinstance(valid_until_ms, int):
-            raise SynapseError(400, "'valid_until_ms' parameter must be an int")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "'valid_until_ms' parameter must be an int"
+            )
 
         if auth_user.to_string() == user_id:
-            raise SynapseError(400, "Cannot use admin API to login as self")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "Cannot use admin API to login as self"
+            )
 
-        token = await self.auth_handler.get_access_token_for_user_id(
+        token = await self.auth_handler.create_access_token_for_user_id(
             user_id=auth_user.to_string(),
             device_id=None,
             valid_until_ms=valid_until_ms,
             puppets_user_id=user_id,
         )
 
-        return 200, {"access_token": token}
+        return HTTPStatus.OK, {"access_token": token}
 
 
 class ShadowBanRestServlet(RestServlet):
@@ -947,11 +988,13 @@ class ShadowBanRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         if not self.hs.is_mine_id(user_id):
-            raise SynapseError(400, "Only local users can be shadow-banned")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned"
+            )
 
         await self.store.set_shadow_banned(UserID.from_string(user_id), True)
 
-        return 200, {}
+        return HTTPStatus.OK, {}
 
     async def on_DELETE(
         self, request: SynapseRequest, user_id: str
@@ -959,11 +1002,13 @@ class ShadowBanRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         if not self.hs.is_mine_id(user_id):
-            raise SynapseError(400, "Only local users can be shadow-banned")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned"
+            )
 
         await self.store.set_shadow_banned(UserID.from_string(user_id), False)
 
-        return 200, {}
+        return HTTPStatus.OK, {}
 
 
 class RateLimitRestServlet(RestServlet):
@@ -995,7 +1040,7 @@ class RateLimitRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         if not self.hs.is_mine_id(user_id):
-            raise SynapseError(400, "Can only look up local users")
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
 
         if not await self.store.get_user_by_id(user_id):
             raise NotFoundError("User not found")
@@ -1016,7 +1061,7 @@ class RateLimitRestServlet(RestServlet):
         else:
             ret = {}
 
-        return 200, ret
+        return HTTPStatus.OK, ret
 
     async def on_POST(
         self, request: SynapseRequest, user_id: str
@@ -1024,7 +1069,9 @@ class RateLimitRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         if not self.hs.is_mine_id(user_id):
-            raise SynapseError(400, "Only local users can be ratelimited")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited"
+            )
 
         if not await self.store.get_user_by_id(user_id):
             raise NotFoundError("User not found")
@@ -1036,14 +1083,14 @@ class RateLimitRestServlet(RestServlet):
 
         if not isinstance(messages_per_second, int) or messages_per_second < 0:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "%r parameter must be a positive int" % (messages_per_second,),
                 errcode=Codes.INVALID_PARAM,
             )
 
         if not isinstance(burst_count, int) or burst_count < 0:
             raise SynapseError(
-                400,
+                HTTPStatus.BAD_REQUEST,
                 "%r parameter must be a positive int" % (burst_count,),
                 errcode=Codes.INVALID_PARAM,
             )
@@ -1059,7 +1106,7 @@ class RateLimitRestServlet(RestServlet):
             "burst_count": ratelimit.burst_count,
         }
 
-        return 200, ret
+        return HTTPStatus.OK, ret
 
     async def on_DELETE(
         self, request: SynapseRequest, user_id: str
@@ -1067,11 +1114,13 @@ class RateLimitRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         if not self.hs.is_mine_id(user_id):
-            raise SynapseError(400, "Only local users can be ratelimited")
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited"
+            )
 
         if not await self.store.get_user_by_id(user_id):
             raise NotFoundError("User not found")
 
         await self.store.delete_ratelimit_for_user(user_id)
 
-        return 200, {}
+        return HTTPStatus.OK, {}
diff --git a/synapse/rest/client/_base.py b/synapse/rest/client/_base.py
index a0971ce994..b4cb90cb76 100644
--- a/synapse/rest/client/_base.py
+++ b/synapse/rest/client/_base.py
@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
 
 def client_patterns(
     path_regex: str,
-    releases: Iterable[int] = (0,),
+    releases: Iterable[str] = ("r0", "v3"),
     unstable: bool = True,
     v1: bool = False,
 ) -> Iterable[Pattern]:
@@ -52,7 +52,7 @@ def client_patterns(
         v1_prefix = CLIENT_API_PREFIX + "/api/v1"
         patterns.append(re.compile("^" + v1_prefix + path_regex))
     for release in releases:
-        new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,)
+        new_prefix = CLIENT_API_PREFIX + f"/{release}"
         patterns.append(re.compile("^" + new_prefix + path_regex))
 
     return patterns
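
The releases change above means a single registration now matches both the historical r0 prefix and the spec-stable v3 prefix. A rough sketch of the resulting patterns, assuming CLIENT_API_PREFIX is "/_matrix/client" as defined elsewhere in Synapse (a simplified stand-in, not the real helper):

    import re

    CLIENT_API_PREFIX = "/_matrix/client"  # assumed value

    def sketch_client_patterns(path_regex, releases=("r0", "v3"), unstable=True):
        """Simplified stand-in for client_patterns, for illustration only."""
        patterns = []
        if unstable:
            patterns.append(re.compile("^" + CLIENT_API_PREFIX + "/unstable" + path_regex))
        for release in releases:
            patterns.append(re.compile("^" + CLIENT_API_PREFIX + f"/{release}" + path_regex))
        return patterns

    for pattern in sketch_client_patterns("/account/whoami$"):
        print(pattern.pattern)
    # ^/_matrix/client/unstable/account/whoami$
    # ^/_matrix/client/r0/account/whoami$
    # ^/_matrix/client/v3/account/whoami$
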
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index 7281b2ee29..730c18f08f 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -262,7 +262,7 @@ class SigningKeyUploadServlet(RestServlet):
     }
     """
 
-    PATTERNS = client_patterns("/keys/device_signing/upload$", releases=())
+    PATTERNS = client_patterns("/keys/device_signing/upload$", releases=("v3",))
 
     def __init__(self, hs: "HomeServer"):
         super().__init__()
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index 467444a041..f9994658c4 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -14,7 +14,17 @@
 
 import logging
 import re
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
 
 from typing_extensions import TypedDict
 
@@ -28,7 +38,6 @@ from synapse.http.server import HttpServer, finish_request
 from synapse.http.servlet import (
     RestServlet,
     assert_params_in_dict,
-    parse_boolean,
     parse_bytes_from_args,
     parse_json_object_from_request,
     parse_string,
@@ -63,7 +72,7 @@ class LoginRestServlet(RestServlet):
     JWT_TYPE_DEPRECATED = "m.login.jwt"
     APPSERVICE_TYPE = "m.login.application_service"
     APPSERVICE_TYPE_UNSTABLE = "uk.half-shot.msc2778.login.application_service"
-    REFRESH_TOKEN_PARAM = "org.matrix.msc2918.refresh_token"
+    REFRESH_TOKEN_PARAM = "refresh_token"
 
     def __init__(self, hs: "HomeServer"):
         super().__init__()
@@ -72,6 +81,7 @@ class LoginRestServlet(RestServlet):
         # JWT configuration variables.
         self.jwt_enabled = hs.config.jwt.jwt_enabled
         self.jwt_secret = hs.config.jwt.jwt_secret
+        self.jwt_subject_claim = hs.config.jwt.jwt_subject_claim
         self.jwt_algorithm = hs.config.jwt.jwt_algorithm
         self.jwt_issuer = hs.config.jwt.jwt_issuer
         self.jwt_audiences = hs.config.jwt.jwt_audiences
@@ -80,7 +90,9 @@ class LoginRestServlet(RestServlet):
         self.saml2_enabled = hs.config.saml2.saml2_enabled
         self.cas_enabled = hs.config.cas.cas_enabled
         self.oidc_enabled = hs.config.oidc.oidc_enabled
-        self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None
+        self._refresh_tokens_enabled = (
+            hs.config.registration.refreshable_access_token_lifetime is not None
+        )
 
         self.auth = hs.get_auth()
 
@@ -151,14 +163,16 @@ class LoginRestServlet(RestServlet):
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]:
         login_submission = parse_json_object_from_request(request)
 
-        if self._msc2918_enabled:
-            # Check if this login should also issue a refresh token, as per
-            # MSC2918
-            should_issue_refresh_token = parse_boolean(
-                request, name=LoginRestServlet.REFRESH_TOKEN_PARAM, default=False
-            )
-        else:
-            should_issue_refresh_token = False
+        # Check to see if the client requested a refresh token.
+        client_requested_refresh_token = login_submission.get(
+            LoginRestServlet.REFRESH_TOKEN_PARAM, False
+        )
+        if not isinstance(client_requested_refresh_token, bool):
+            raise SynapseError(400, "`refresh_token` should be true or false.")
+
+        should_issue_refresh_token = (
+            self._refresh_tokens_enabled and client_requested_refresh_token
+        )
 
         try:
             if login_submission["type"] in (
@@ -288,6 +302,7 @@ class LoginRestServlet(RestServlet):
         ratelimit: bool = True,
         auth_provider_id: Optional[str] = None,
         should_issue_refresh_token: bool = False,
+        auth_provider_session_id: Optional[str] = None,
     ) -> LoginResponse:
         """Called when we've successfully authed the user and now need to
         actually login them in (e.g. create devices). This gets called on
@@ -303,10 +318,10 @@ class LoginRestServlet(RestServlet):
             create_non_existent_users: Whether to create the user if they don't
                 exist. Defaults to False.
             ratelimit: Whether to ratelimit the login request.
-            auth_provider_id: The SSO IdP the user used, if any (just used for the
-                prometheus metrics).
+            auth_provider_id: The SSO IdP the user used, if any.
             should_issue_refresh_token: True if this login should issue
                 a refresh token alongside the access token.
+            auth_provider_session_id: The SSO IdP session ID obtained during login.
 
         Returns:
             result: Dictionary of account information after successful login.
@@ -339,6 +354,7 @@ class LoginRestServlet(RestServlet):
             initial_display_name,
             auth_provider_id=auth_provider_id,
             should_issue_refresh_token=should_issue_refresh_token,
+            auth_provider_session_id=auth_provider_session_id,
         )
 
         result = LoginResponse(
@@ -384,6 +400,7 @@ class LoginRestServlet(RestServlet):
             self.auth_handler._sso_login_callback,
             auth_provider_id=res.auth_provider_id,
             should_issue_refresh_token=should_issue_refresh_token,
+            auth_provider_session_id=res.auth_provider_session_id,
         )
 
     async def _do_jwt_login(
@@ -413,7 +430,7 @@ class LoginRestServlet(RestServlet):
                 errcode=Codes.FORBIDDEN,
             )
 
-        user = payload.get("sub", None)
+        user = payload.get(self.jwt_subject_claim, None)
         if user is None:
             raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)
 
@@ -445,14 +462,15 @@ def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
 
 
 class RefreshTokenServlet(RestServlet):
-    PATTERNS = client_patterns(
-        "/org.matrix.msc2918.refresh_token/refresh$", releases=(), unstable=True
-    )
+    PATTERNS = (re.compile("^/_matrix/client/v1/refresh$"),)
 
     def __init__(self, hs: "HomeServer"):
         self._auth_handler = hs.get_auth_handler()
         self._clock = hs.get_clock()
-        self.access_token_lifetime = hs.config.registration.access_token_lifetime
+        self.refreshable_access_token_lifetime = (
+            hs.config.registration.refreshable_access_token_lifetime
+        )
+        self.refresh_token_lifetime = hs.config.registration.refresh_token_lifetime
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         refresh_submission = parse_json_object_from_request(request)
@@ -462,27 +480,40 @@ class RefreshTokenServlet(RestServlet):
         if not isinstance(token, str):
             raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM)
 
-        valid_until_ms = self._clock.time_msec() + self.access_token_lifetime
-        access_token, refresh_token = await self._auth_handler.refresh_token(
-            token, valid_until_ms
-        )
-        expires_in_ms = valid_until_ms - self._clock.time_msec()
-        return (
-            200,
-            {
-                "access_token": access_token,
-                "refresh_token": refresh_token,
-                "expires_in_ms": expires_in_ms,
-            },
+        now = self._clock.time_msec()
+        access_valid_until_ms = None
+        if self.refreshable_access_token_lifetime is not None:
+            access_valid_until_ms = now + self.refreshable_access_token_lifetime
+        refresh_valid_until_ms = None
+        if self.refresh_token_lifetime is not None:
+            refresh_valid_until_ms = now + self.refresh_token_lifetime
+
+        (
+            access_token,
+            refresh_token,
+            actual_access_token_expiry,
+        ) = await self._auth_handler.refresh_token(
+            token, access_valid_until_ms, refresh_valid_until_ms
         )
 
+        response: Dict[str, Union[str, int]] = {
+            "access_token": access_token,
+            "refresh_token": refresh_token,
+        }
+
+        # expires_in_ms is only present if the token expires
+        if actual_access_token_expiry is not None:
+            response["expires_in_ms"] = actual_access_token_expiry - now
+
+        return 200, response
+
 
 class SsoRedirectServlet(RestServlet):
     PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [
         re.compile(
             "^"
             + CLIENT_API_PREFIX
-            + "/r0/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$"
+            + "/(r0|v3)/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$"
         )
     ]
 
@@ -561,7 +592,7 @@ class CasTicketServlet(RestServlet):
 
 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     LoginRestServlet(hs).register(http_server)
-    if hs.config.registration.access_token_lifetime is not None:
+    if hs.config.registration.refreshable_access_token_lifetime is not None:
         RefreshTokenServlet(hs).register(http_server)
     SsoRedirectServlet(hs).register(http_server)
     if hs.config.cas.cas_enabled:
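
Taken together, the login.py changes above stabilise the refresh token flow: the opt-in flag moves from an MSC2918 query parameter into the login body, and the refresh endpoint moves to /_matrix/client/v1/refresh. A hedged sketch of how a client might use it, with placeholder credentials and the requests library standing in for a Matrix SDK (refresh tokens are only issued when refreshable_access_token_lifetime is configured):

    import requests  # illustration only; URLs and credentials are placeholders

    HOMESERVER = "https://localhost:8008"

    # Ask for a refresh token via the JSON body rather than the old
    # org.matrix.msc2918.refresh_token query parameter.
    login = requests.post(
        f"{HOMESERVER}/_matrix/client/v3/login",
        json={
            "type": "m.login.password",
            "identifier": {"type": "m.id.user", "user": "alice"},
            "password": "correct horse battery staple",
            "refresh_token": True,
        },
    ).json()

    # Later, exchange the refresh token at the new stable endpoint.
    refreshed = requests.post(
        f"{HOMESERVER}/_matrix/client/v1/refresh",
        json={"refresh_token": login["refresh_token"]},
    ).json()

    # "expires_in_ms" is only present when the new access token expires.
    print(refreshed["access_token"], refreshed.get("expires_in_ms"))
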
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index bf3cb34146..8b56c76aed 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -41,7 +41,6 @@ from synapse.http.server import HttpServer, finish_request, respond_with_html
 from synapse.http.servlet import (
     RestServlet,
     assert_params_in_dict,
-    parse_boolean,
     parse_json_object_from_request,
     parse_string,
 )
@@ -420,7 +419,9 @@ class RegisterRestServlet(RestServlet):
         self.password_policy_handler = hs.get_password_policy_handler()
         self.clock = hs.get_clock()
         self._registration_enabled = self.hs.config.registration.enable_registration
-        self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None
+        self._refresh_tokens_enabled = (
+            hs.config.registration.refreshable_access_token_lifetime is not None
+        )
 
         self._registration_flows = _calculate_registration_flows(
             hs.config, self.auth_handler
@@ -444,14 +445,15 @@ class RegisterRestServlet(RestServlet):
                 f"Do not understand membership kind: {kind}",
             )
 
-        if self._msc2918_enabled:
-            # Check if this registration should also issue a refresh token, as
-            # per MSC2918
-            should_issue_refresh_token = parse_boolean(
-                request, name="org.matrix.msc2918.refresh_token", default=False
-            )
-        else:
-            should_issue_refresh_token = False
+        # Check if the client wishes for this registration to issue a refresh
+        # token.
+        client_requested_refresh_tokens = body.get("refresh_token", False)
+        if not isinstance(client_requested_refresh_tokens, bool):
+            raise SynapseError(400, "`refresh_token` should be true or false.")
+
+        should_issue_refresh_token = (
+            self._refresh_tokens_enabled and client_requested_refresh_tokens
+        )
 
         # Pull out the provided username and do basic sanity checks early since
         # the auth layer will store these in sessions.
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index 184cfbe196..fc4e6921c5 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -224,18 +224,14 @@ class RelationPaginationServlet(RestServlet):
         )
 
         now = self.clock.time_msec()
-        # We set bundle_aggregations to False when retrieving the original
-        # event because we want the content before relations were applied to
-        # it.
+        # Do not bundle aggregations when retrieving the original event because
+        # we want the content before relations are applied to it.
         original_event = await self._event_serializer.serialize_event(
             event, now, bundle_aggregations=False
         )
-        # Similarly, we don't allow relations to be applied to relations, so we
-        # return the original relations without any aggregations on top of them
-        # here.
-        serialized_events = await self._event_serializer.serialize_events(
-            events, now, bundle_aggregations=False
-        )
+        # The relations returned for the requested event do include their
+        # bundled aggregations.
+        serialized_events = await self._event_serializer.serialize_events(events, now)
 
         return_value = pagination_chunk.to_dict()
         return_value["chunk"] = serialized_events
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 03a353d53c..f48e2e6ca2 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -716,10 +716,7 @@ class RoomEventContextServlet(RestServlet):
             results["events_after"], time_now
         )
         results["state"] = await self._event_serializer.serialize_events(
-            results["state"],
-            time_now,
-            # No need to bundle aggregations for state events
-            bundle_aggregations=False,
+            results["state"], time_now
         )
 
         return 200, results
@@ -1070,6 +1067,62 @@ def register_txn_path(
         )
 
 
+class TimestampLookupRestServlet(RestServlet):
+    """
+    API endpoint to fetch the `event_id` of the closest event to the given
+    timestamp (`ts` query parameter) in the given direction (`dir` query
+    parameter).
+
+    Useful for cases like "jump to date", so you can start paginating messages
+    from a given date in the archive.
+
+    `ts` is a timestamp in milliseconds where we will find the closest event in
+    the given direction.
+
+    `dir` can be `f` or `b` to indicate forwards and backwards in time from the
+    given timestamp.
+
+    GET /_matrix/client/unstable/org.matrix.msc3030/rooms/<roomID>/timestamp_to_event?ts=<timestamp>&dir=<direction>
+    {
+        "event_id": ...
+    }
+    """
+
+    PATTERNS = (
+        re.compile(
+            "^/_matrix/client/unstable/org.matrix.msc3030"
+            "/rooms/(?P<room_id>[^/]*)/timestamp_to_event$"
+        ),
+    )
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self._auth = hs.get_auth()
+        self._store = hs.get_datastore()
+        self.timestamp_lookup_handler = hs.get_timestamp_lookup_handler()
+
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
+        requester = await self._auth.get_user_by_req(request)
+        await self._auth.check_user_in_room(room_id, requester.user.to_string())
+
+        timestamp = parse_integer(request, "ts", required=True)
+        direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])
+
+        (
+            event_id,
+            origin_server_ts,
+        ) = await self.timestamp_lookup_handler.get_event_for_timestamp(
+            requester, room_id, timestamp, direction
+        )
+
+        return 200, {
+            "event_id": event_id,
+            "origin_server_ts": origin_server_ts,
+        }
+
+
 class RoomSpaceSummaryRestServlet(RestServlet):
     PATTERNS = (
         re.compile(
@@ -1140,7 +1193,7 @@ class RoomSpaceSummaryRestServlet(RestServlet):
 class RoomHierarchyRestServlet(RestServlet):
     PATTERNS = (
         re.compile(
-            "^/_matrix/client/unstable/org.matrix.msc2946"
+            "^/_matrix/client/(v1|unstable/org.matrix.msc2946)"
             "/rooms/(?P<room_id>[^/]*)/hierarchy$"
         ),
     )
@@ -1168,7 +1221,7 @@ class RoomHierarchyRestServlet(RestServlet):
             )
 
         return 200, await self._room_summary_handler.get_room_hierarchy(
-            requester.user.to_string(),
+            requester,
             room_id,
             suggested_only=parse_boolean(request, "suggested_only", default=False),
             max_depth=max_depth,
@@ -1239,6 +1292,8 @@ def register_servlets(
     RoomAliasListServlet(hs).register(http_server)
     SearchRestServlet(hs).register(http_server)
     RoomCreateRestServlet(hs).register(http_server)
+    if hs.config.experimental.msc3030_enabled:
+        TimestampLookupRestServlet(hs).register(http_server)
 
     # Some servlets only get registered for the main process.
     if not is_worker:
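
The TimestampLookupRestServlet added above is gated behind msc3030_enabled and lives under the unstable prefix shown in its PATTERNS. A minimal sketch of querying it, with placeholder homeserver, token and room ID:

    import requests  # illustration only; placeholders throughout
    from urllib.parse import quote

    HOMESERVER = "https://localhost:8008"
    TOKEN = "<access token>"
    ROOM_ID = "!example:example.com"

    resp = requests.get(
        f"{HOMESERVER}/_matrix/client/unstable/org.matrix.msc3030"
        f"/rooms/{quote(ROOM_ID, safe='')}/timestamp_to_event",
        params={"ts": 1546300800000, "dir": "f"},  # closest event after 2019-01-01 UTC
        headers={"Authorization": f"Bearer {TOKEN}"},
    )
    print(resp.json())  # {"event_id": ..., "origin_server_ts": ...}
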
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 8c0fdb1940..88e4f5e063 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -520,9 +520,9 @@ class SyncRestServlet(RestServlet):
             return self._event_serializer.serialize_events(
                 events,
                 time_now=time_now,
-                # We don't bundle "live" events, as otherwise clients
-                # will end up double counting annotations.
-                bundle_aggregations=False,
+                # Don't bother to bundle aggregations if the timeline is unlimited,
+                # as clients will have all the necessary information.
+                bundle_aggregations=room.timeline.limited,
                 token_id=token_id,
                 event_format=event_formatter,
                 only_event_fields=only_fields,
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 014fa893d6..9b40fd8a6c 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -29,7 +29,7 @@ from synapse.api.errors import Codes, SynapseError, cs_error
 from synapse.http.server import finish_request, respond_with_json
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
-from synapse.util.stringutils import is_ascii
+from synapse.util.stringutils import is_ascii, parse_and_validate_server_name
 
 logger = logging.getLogger(__name__)
 
@@ -51,6 +51,19 @@ TEXT_CONTENT_TYPES = [
 
 
 def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
+    """Parses the server name, media ID and optional file name from the request URI
+
+    Also performs some rough validation on the server name.
+
+    Args:
+        request: The `Request`.
+
+    Returns:
+        A tuple containing the parsed server name, media ID and optional file name.
+
+    Raises:
+        SynapseError(404): if parsing or validation fail for any reason
+    """
     try:
         # The type on postpath seems incorrect in Twisted 21.2.0.
         postpath: List[bytes] = request.postpath  # type: ignore
@@ -62,6 +75,9 @@ def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
         server_name = server_name_bytes.decode("utf-8")
         media_id = media_id_bytes.decode("utf8")
 
+        # Validate the server name, raising if invalid
+        parse_and_validate_server_name(server_name)
+
         file_name = None
         if len(postpath) > 2:
             try:
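
To illustrate the effect of the server-name validation added to parse_media_id: a download URL whose server-name segment cannot be parsed as a valid server name (for example, one with a non-numeric port) is now rejected with a 404 before any storage path is built. A sketch, assuming the standard r0 media download path and a placeholder homeserver:

    import requests  # illustration only; the homeserver URL is a placeholder

    HOMESERVER = "https://localhost:8008"

    # "example.com:notaport" is not a valid server name, so parse_media_id is
    # expected to raise SynapseError(404) per the docstring added above.
    resp = requests.get(
        f"{HOMESERVER}/_matrix/media/r0/download/example.com:notaport/somemediaid"
    )
    print(resp.status_code)  # expected: 404
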
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index bec77088ee..1f6441c412 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -16,7 +16,8 @@
 import functools
 import os
 import re
-from typing import Any, Callable, List, TypeVar, cast
+import string
+from typing import Any, Callable, List, TypeVar, Union, cast
 
 NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
 
@@ -37,6 +38,113 @@ def _wrap_in_base_path(func: F) -> F:
     return cast(F, _wrapped)
 
 
+GetPathMethod = TypeVar(
+    "GetPathMethod", bound=Union[Callable[..., str], Callable[..., List[str]]]
+)
+
+
+def _wrap_with_jail_check(relative: bool) -> Callable[[GetPathMethod], GetPathMethod]:
+    """Wraps a path-returning method to check that the returned path(s) do not escape
+    the media store directory.
+
+    The path-returning method may return either a single path, or a list of paths.
+
+    The check is not expected to ever fail, unless `func` is missing a call to
+    `_validate_path_component`, or `_validate_path_component` is buggy.
+
+    Args:
+        relative: A boolean indicating whether the wrapped method returns paths relative
+            to the media store directory.
+
+    Returns:
+        A method which will wrap a path-returning method, adding a check to ensure that
+        the returned path(s) lie within the media store directory. The check will raise
+        a `ValueError` if it fails.
+    """
+
+    def _wrap_with_jail_check_inner(func: GetPathMethod) -> GetPathMethod:
+        @functools.wraps(func)
+        def _wrapped(
+            self: "MediaFilePaths", *args: Any, **kwargs: Any
+        ) -> Union[str, List[str]]:
+            path_or_paths = func(self, *args, **kwargs)
+
+            if isinstance(path_or_paths, list):
+                paths_to_check = path_or_paths
+            else:
+                paths_to_check = [path_or_paths]
+
+            for path in paths_to_check:
+                # Construct the path that will ultimately be used.
+                # We cannot guess whether `path` is relative to the media store
+                # directory, since the media store directory may itself be a relative
+                # path.
+                if relative:
+                    path = os.path.join(self.base_path, path)
+                normalized_path = os.path.normpath(path)
+
+                # Now that `normpath` has eliminated `../`s and `./`s from the path,
+                # `os.path.commonpath` can be used to check whether it lies within the
+                # media store directory.
+                if (
+                    os.path.commonpath([normalized_path, self.normalized_base_path])
+                    != self.normalized_base_path
+                ):
+                    # The path resolves to outside the media store directory,
+                    # or `self.base_path` is `.`, which is an unlikely configuration.
+                    raise ValueError(f"Invalid media store path: {path!r}")
+
+                # Note that `os.path.normpath`/`abspath` has a subtle caveat:
+                # `a/b/c/../c` will normalize to `a/b/c`, but the former refers to a
+                # different path if `a/b/c` is a symlink. That is, the check above is
+                # not perfect and may allow a certain restricted subset of untrustworthy
+                # paths through. Since the check above is secondary to the main
+                # `_validate_path_component` checks, it's less important for it to be
+                # perfect.
+                #
+                # As an alternative, `os.path.realpath` will resolve symlinks, but
+                # proves problematic if there are symlinks inside the media store.
+                # eg. if `url_store/` is symlinked to elsewhere, its canonical path
+                # won't match that of the main media store directory.
+
+            return path_or_paths
+
+        return cast(GetPathMethod, _wrapped)
+
+    return _wrap_with_jail_check_inner
+
+
+ALLOWED_CHARACTERS = set(
+    string.ascii_letters
+    + string.digits
+    + "_-"
+    + ".[]:"  # Domain names, IPv6 addresses and ports in server names
+)
+FORBIDDEN_NAMES = {
+    "",
+    os.path.curdir,  # "." for the current platform
+    os.path.pardir,  # ".." for the current platform
+}
+
+
+def _validate_path_component(name: str) -> str:
+    """Checks that the given string can be safely used as a path component
+
+    Args:
+        name: The path component to check.
+
+    Returns:
+        The path component if valid.
+
+    Raises:
+        ValueError: If `name` cannot be safely used as a path component.
+    """
+    if not ALLOWED_CHARACTERS.issuperset(name) or name in FORBIDDEN_NAMES:
+        raise ValueError(f"Invalid path component: {name!r}")
+
+    return name
+
+
 class MediaFilePaths:
     """Describes where files are stored on disk.
 
@@ -47,23 +155,45 @@ class MediaFilePaths:
 
     def __init__(self, primary_base_path: str):
         self.base_path = primary_base_path
-
+        self.normalized_base_path = os.path.normpath(self.base_path)
+
+        # Refuse to initialize if paths cannot be validated correctly for the current
+        # platform.
+        assert os.path.sep not in ALLOWED_CHARACTERS
+        assert os.path.altsep not in ALLOWED_CHARACTERS
+        # On Windows, paths have all sorts of weirdness which `_validate_path_component`
+        # does not consider. In any case, the remote media store can't work correctly
+        # for certain homeservers there, since ":"s aren't allowed in paths.
+        assert os.name == "posix"
+
+    @_wrap_with_jail_check(relative=True)
     def local_media_filepath_rel(self, media_id: str) -> str:
-        return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:])
+        return os.path.join(
+            "local_content",
+            _validate_path_component(media_id[0:2]),
+            _validate_path_component(media_id[2:4]),
+            _validate_path_component(media_id[4:]),
+        )
 
     local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)
 
+    @_wrap_with_jail_check(relative=True)
     def local_media_thumbnail_rel(
         self, media_id: str, width: int, height: int, content_type: str, method: str
     ) -> str:
         top_level_type, sub_type = content_type.split("/")
         file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
         return os.path.join(
-            "local_thumbnails", media_id[0:2], media_id[2:4], media_id[4:], file_name
+            "local_thumbnails",
+            _validate_path_component(media_id[0:2]),
+            _validate_path_component(media_id[2:4]),
+            _validate_path_component(media_id[4:]),
+            _validate_path_component(file_name),
         )
 
     local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel)
 
+    @_wrap_with_jail_check(relative=False)
     def local_media_thumbnail_dir(self, media_id: str) -> str:
         """
         Retrieve the local store path of thumbnails of a given media_id
@@ -76,18 +206,24 @@ class MediaFilePaths:
         return os.path.join(
             self.base_path,
             "local_thumbnails",
-            media_id[0:2],
-            media_id[2:4],
-            media_id[4:],
+            _validate_path_component(media_id[0:2]),
+            _validate_path_component(media_id[2:4]),
+            _validate_path_component(media_id[4:]),
         )
 
+    @_wrap_with_jail_check(relative=True)
     def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str:
         return os.path.join(
-            "remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:]
+            "remote_content",
+            _validate_path_component(server_name),
+            _validate_path_component(file_id[0:2]),
+            _validate_path_component(file_id[2:4]),
+            _validate_path_component(file_id[4:]),
         )
 
     remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)
 
+    @_wrap_with_jail_check(relative=True)
     def remote_media_thumbnail_rel(
         self,
         server_name: str,
@@ -101,11 +237,11 @@ class MediaFilePaths:
         file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
         return os.path.join(
             "remote_thumbnail",
-            server_name,
-            file_id[0:2],
-            file_id[2:4],
-            file_id[4:],
-            file_name,
+            _validate_path_component(server_name),
+            _validate_path_component(file_id[0:2]),
+            _validate_path_component(file_id[2:4]),
+            _validate_path_component(file_id[4:]),
+            _validate_path_component(file_name),
         )
 
     remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel)
@@ -113,6 +249,7 @@ class MediaFilePaths:
     # Legacy path that was used to store thumbnails previously.
     # Should be removed after some time, when most of the thumbnails are stored
     # using the new path.
+    @_wrap_with_jail_check(relative=True)
     def remote_media_thumbnail_rel_legacy(
         self, server_name: str, file_id: str, width: int, height: int, content_type: str
     ) -> str:
@@ -120,43 +257,67 @@ class MediaFilePaths:
         file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
         return os.path.join(
             "remote_thumbnail",
-            server_name,
-            file_id[0:2],
-            file_id[2:4],
-            file_id[4:],
-            file_name,
+            _validate_path_component(server_name),
+            _validate_path_component(file_id[0:2]),
+            _validate_path_component(file_id[2:4]),
+            _validate_path_component(file_id[4:]),
+            _validate_path_component(file_name),
         )
 
+    @_wrap_with_jail_check(relative=False)
     def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str:
         return os.path.join(
             self.base_path,
             "remote_thumbnail",
-            server_name,
-            file_id[0:2],
-            file_id[2:4],
-            file_id[4:],
+            _validate_path_component(server_name),
+            _validate_path_component(file_id[0:2]),
+            _validate_path_component(file_id[2:4]),
+            _validate_path_component(file_id[4:]),
         )
 
+    @_wrap_with_jail_check(relative=True)
     def url_cache_filepath_rel(self, media_id: str) -> str:
         if NEW_FORMAT_ID_RE.match(media_id):
             # Media id is of the form <DATE><RANDOM_STRING>
             # E.g.: 2017-09-28-fsdRDt24DS234dsf
-            return os.path.join("url_cache", media_id[:10], media_id[11:])
+            return os.path.join(
+                "url_cache",
+                _validate_path_component(media_id[:10]),
+                _validate_path_component(media_id[11:]),
+            )
         else:
-            return os.path.join("url_cache", media_id[0:2], media_id[2:4], media_id[4:])
+            return os.path.join(
+                "url_cache",
+                _validate_path_component(media_id[0:2]),
+                _validate_path_component(media_id[2:4]),
+                _validate_path_component(media_id[4:]),
+            )
 
     url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)
 
+    @_wrap_with_jail_check(relative=False)
     def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]:
         "The dirs to try and remove if we delete the media_id file"
         if NEW_FORMAT_ID_RE.match(media_id):
-            return [os.path.join(self.base_path, "url_cache", media_id[:10])]
+            return [
+                os.path.join(
+                    self.base_path, "url_cache", _validate_path_component(media_id[:10])
+                )
+            ]
         else:
             return [
-                os.path.join(self.base_path, "url_cache", media_id[0:2], media_id[2:4]),
-                os.path.join(self.base_path, "url_cache", media_id[0:2]),
+                os.path.join(
+                    self.base_path,
+                    "url_cache",
+                    _validate_path_component(media_id[0:2]),
+                    _validate_path_component(media_id[2:4]),
+                ),
+                os.path.join(
+                    self.base_path, "url_cache", _validate_path_component(media_id[0:2])
+                ),
             ]
 
+    @_wrap_with_jail_check(relative=True)
     def url_cache_thumbnail_rel(
         self, media_id: str, width: int, height: int, content_type: str, method: str
     ) -> str:
@@ -168,37 +329,46 @@ class MediaFilePaths:
 
         if NEW_FORMAT_ID_RE.match(media_id):
             return os.path.join(
-                "url_cache_thumbnails", media_id[:10], media_id[11:], file_name
+                "url_cache_thumbnails",
+                _validate_path_component(media_id[:10]),
+                _validate_path_component(media_id[11:]),
+                _validate_path_component(file_name),
             )
         else:
             return os.path.join(
                 "url_cache_thumbnails",
-                media_id[0:2],
-                media_id[2:4],
-                media_id[4:],
-                file_name,
+                _validate_path_component(media_id[0:2]),
+                _validate_path_component(media_id[2:4]),
+                _validate_path_component(media_id[4:]),
+                _validate_path_component(file_name),
             )
 
     url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
 
+    @_wrap_with_jail_check(relative=True)
     def url_cache_thumbnail_directory_rel(self, media_id: str) -> str:
         # Media id is of the form <DATE><RANDOM_STRING>
         # E.g.: 2017-09-28-fsdRDt24DS234dsf
 
         if NEW_FORMAT_ID_RE.match(media_id):
-            return os.path.join("url_cache_thumbnails", media_id[:10], media_id[11:])
+            return os.path.join(
+                "url_cache_thumbnails",
+                _validate_path_component(media_id[:10]),
+                _validate_path_component(media_id[11:]),
+            )
         else:
             return os.path.join(
                 "url_cache_thumbnails",
-                media_id[0:2],
-                media_id[2:4],
-                media_id[4:],
+                _validate_path_component(media_id[0:2]),
+                _validate_path_component(media_id[2:4]),
+                _validate_path_component(media_id[4:]),
             )
 
     url_cache_thumbnail_directory = _wrap_in_base_path(
         url_cache_thumbnail_directory_rel
     )
 
+    @_wrap_with_jail_check(relative=False)
     def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]:
         "The dirs to try and remove if we delete the media_id thumbnails"
         # Media id is of the form <DATE><RANDOM_STRING>
@@ -206,21 +376,35 @@ class MediaFilePaths:
         if NEW_FORMAT_ID_RE.match(media_id):
             return [
                 os.path.join(
-                    self.base_path, "url_cache_thumbnails", media_id[:10], media_id[11:]
+                    self.base_path,
+                    "url_cache_thumbnails",
+                    _validate_path_component(media_id[:10]),
+                    _validate_path_component(media_id[11:]),
+                ),
+                os.path.join(
+                    self.base_path,
+                    "url_cache_thumbnails",
+                    _validate_path_component(media_id[:10]),
                 ),
-                os.path.join(self.base_path, "url_cache_thumbnails", media_id[:10]),
             ]
         else:
             return [
                 os.path.join(
                     self.base_path,
                     "url_cache_thumbnails",
-                    media_id[0:2],
-                    media_id[2:4],
-                    media_id[4:],
+                    _validate_path_component(media_id[0:2]),
+                    _validate_path_component(media_id[2:4]),
+                    _validate_path_component(media_id[4:]),
                 ),
                 os.path.join(
-                    self.base_path, "url_cache_thumbnails", media_id[0:2], media_id[2:4]
+                    self.base_path,
+                    "url_cache_thumbnails",
+                    _validate_path_component(media_id[0:2]),
+                    _validate_path_component(media_id[2:4]),
+                ),
+                os.path.join(
+                    self.base_path,
+                    "url_cache_thumbnails",
+                    _validate_path_component(media_id[0:2]),
                 ),
-                os.path.join(self.base_path, "url_cache_thumbnails", media_id[0:2]),
             ]
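Illustrative sketch (not part of the patch): a standalone rendering of the two ideas above, the per-component allow-list check and a normpath-based containment check in the spirit of the jail wrapper; the helper names here are invented for the example.

    import os
    import string

    ALLOWED = set(string.ascii_letters + string.digits + "_-" + ".[]:")
    FORBIDDEN = {"", os.path.curdir, os.path.pardir}

    def validate_path_component(name: str) -> str:
        # Reject names with characters outside the allow-list (which excludes
        # "/" and "\\") and the reserved names "", "." and "..".
        if not ALLOWED.issuperset(name) or name in FORBIDDEN:
            raise ValueError(f"Invalid path component: {name!r}")
        return name

    def in_media_store(base: str, relpath: str) -> bool:
        # Containment check: the normalised joined path must stay under the
        # normalised base path.
        base = os.path.normpath(base)
        full = os.path.normpath(os.path.join(base, relpath))
        return full == base or full.startswith(base + os.sep)

    print(validate_path_component("matrix.org:8448"))       # accepted as-is
    print(in_media_store("/media", "local_content/ab/cd"))  # True
    print(in_media_store("/media", "../etc/passwd"))        # False
    try:
        validate_path_component("../etc")
    except ValueError as e:
        print(e)  # rejected: "/" is not an allowed character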
diff --git a/synapse/server.py b/synapse/server.py
index 877eba6c08..185e40e4da 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -97,6 +97,7 @@ from synapse.handlers.room import (
     RoomContextHandler,
     RoomCreationHandler,
     RoomShutdownHandler,
+    TimestampLookupHandler,
 )
 from synapse.handlers.room_batch import RoomBatchHandler
 from synapse.handlers.room_list import RoomListHandler
@@ -729,6 +730,10 @@ class HomeServer(metaclass=abc.ABCMeta):
         return RoomContextHandler(self)
 
     @cache_in_self
+    def get_timestamp_lookup_handler(self) -> TimestampLookupHandler:
+        return TimestampLookupHandler(self)
+
+    @cache_in_self
     def get_registration_handler(self) -> RegistrationHandler:
         return RegistrationHandler(self)
 
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 1605411b00..446204dbe5 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -764,7 +764,7 @@ class StateResolutionStore:
     store: "DataStore"
 
     def get_events(
-        self, event_ids: Iterable[str], allow_rejected: bool = False
+        self, event_ids: Collection[str], allow_rejected: bool = False
     ) -> Awaitable[Dict[str, EventBase]]:
         """Get events from the database
 
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 6edadea550..499a328201 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -17,6 +17,7 @@ import logging
 from typing import (
     Awaitable,
     Callable,
+    Collection,
     Dict,
     Iterable,
     List,
@@ -44,7 +45,7 @@ async def resolve_events_with_store(
     room_version: RoomVersion,
     state_sets: Sequence[StateMap[str]],
     event_map: Optional[Dict[str, EventBase]],
-    state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
+    state_map_factory: Callable[[Collection[str]], Awaitable[Dict[str, EventBase]]],
 ) -> StateMap[str]:
     """
     Args:
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 0623da9aa1..3056e64ff5 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -21,7 +21,7 @@ from synapse.storage.database import LoggingTransaction  # noqa: F401
 from synapse.storage.database import make_in_list_sql_clause  # noqa: F401
 from synapse.storage.database import DatabasePool
 from synapse.storage.types import Connection
-from synapse.types import StreamToken, get_domain_from_id
+from synapse.types import get_domain_from_id
 from synapse.util import json_decoder
 
 if TYPE_CHECKING:
@@ -48,7 +48,7 @@ class SQLBaseStore(metaclass=ABCMeta):
         self,
         stream_name: str,
         instance_name: str,
-        token: StreamToken,
+        token: int,
         rows: Iterable[Any],
     ) -> None:
         pass
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index b9a8ca997e..d64910aded 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -12,12 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Iterable, Optional
+from typing import (
+    TYPE_CHECKING,
+    AsyncContextManager,
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    Optional,
+)
+
+import attr
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.types import Connection
 from synapse.types import JsonDict
-from synapse.util import json_encoder
+from synapse.util import Clock, json_encoder
 
 from . import engines
 
@@ -28,6 +38,45 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+ON_UPDATE_CALLBACK = Callable[[str, str, bool], AsyncContextManager[int]]
+DEFAULT_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
+MIN_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _BackgroundUpdateHandler:
+    """A handler for a given background update.
+
+    Attributes:
+        callback: The function to call to make progress on the background
+            update.
+        oneshot: Whether the update is likely to happen all in one go, ignoring
+            the supplied target duration, e.g. index creation. This is used by
+            the update controller to help correctly schedule the update.
+    """
+
+    callback: Callable[[JsonDict, int], Awaitable[int]]
+    oneshot: bool = False
+
+
+class _BackgroundUpdateContextManager:
+    BACKGROUND_UPDATE_INTERVAL_MS = 1000
+    BACKGROUND_UPDATE_DURATION_MS = 100
+
+    def __init__(self, sleep: bool, clock: Clock):
+        self._sleep = sleep
+        self._clock = clock
+
+    async def __aenter__(self) -> int:
+        if self._sleep:
+            await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000)
+
+        return self.BACKGROUND_UPDATE_DURATION_MS
+
+    async def __aexit__(self, *exc) -> None:
+        pass
+
+
 class BackgroundUpdatePerformance:
     """Tracks the how long a background update is taking to update its items"""
 
@@ -82,22 +131,24 @@ class BackgroundUpdater:
     process and autotuning the batch size.
     """
 
-    MINIMUM_BACKGROUND_BATCH_SIZE = 100
+    MINIMUM_BACKGROUND_BATCH_SIZE = 1
     DEFAULT_BACKGROUND_BATCH_SIZE = 100
-    BACKGROUND_UPDATE_INTERVAL_MS = 1000
-    BACKGROUND_UPDATE_DURATION_MS = 100
 
     def __init__(self, hs: "HomeServer", database: "DatabasePool"):
         self._clock = hs.get_clock()
         self.db_pool = database
 
+        self._database_name = database.name()
+
         # if a background update is currently running, its name.
         self._current_background_update: Optional[str] = None
 
+        self._on_update_callback: Optional[ON_UPDATE_CALLBACK] = None
+        self._default_batch_size_callback: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None
+        self._min_batch_size_callback: Optional[MIN_BATCH_SIZE_CALLBACK] = None
+
         self._background_update_performance: Dict[str, BackgroundUpdatePerformance] = {}
-        self._background_update_handlers: Dict[
-            str, Callable[[JsonDict, int], Awaitable[int]]
-        ] = {}
+        self._background_update_handlers: Dict[str, _BackgroundUpdateHandler] = {}
         self._all_done = False
 
         # Whether we're currently running updates
@@ -107,6 +158,83 @@ class BackgroundUpdater:
         # enable/disable background updates via the admin API.
         self.enabled = True
 
+    def register_update_controller_callbacks(
+        self,
+        on_update: ON_UPDATE_CALLBACK,
+        default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
+        min_batch_size: Optional[MIN_BATCH_SIZE_CALLBACK] = None,
+    ) -> None:
+        """Register callbacks from a module for each hook."""
+        if self._on_update_callback is not None:
+            logger.warning(
+                "More than one module tried to register callbacks for controlling"
+                " background updates. Only the callbacks registered by the first module"
+                " (in order of appearance in Synapse's configuration file) that tried to"
+                " do so will be called."
+            )
+
+            return
+
+        self._on_update_callback = on_update
+
+        if default_batch_size is not None:
+            self._default_batch_size_callback = default_batch_size
+
+        if min_batch_size is not None:
+            self._min_batch_size_callback = min_batch_size
+
+    def _get_context_manager_for_update(
+        self,
+        sleep: bool,
+        update_name: str,
+        database_name: str,
+        oneshot: bool,
+    ) -> AsyncContextManager[int]:
+        """Get a context manager to run a background update with.
+
+        If a module has registered an `on_update` callback, use the context manager
+        it returns.
+
+        Otherwise, returns a context manager that will return a default value, optionally
+        sleeping if needed.
+
+        Args:
+            sleep: Whether we can sleep between updates.
+            update_name: The name of the update.
+            database_name: The name of the database the update is being run on.
+            oneshot: Whether the update will complete all in one go, e.g. index creation.
+                In such cases the returned target duration is ignored.
+
+        Returns:
+            A context manager which, when entered, returns the target duration in
+            milliseconds that the background update should run for.
+
+            Note: this is a *target*, and an iteration may take substantially longer or
+            shorter.
+        """
+        if self._on_update_callback is not None:
+            return self._on_update_callback(update_name, database_name, oneshot)
+
+        return _BackgroundUpdateContextManager(sleep, self._clock)
+
+    async def _default_batch_size(self, update_name: str, database_name: str) -> int:
+        """The batch size to use for the first iteration of a new background
+        update.
+        """
+        if self._default_batch_size_callback is not None:
+            return await self._default_batch_size_callback(update_name, database_name)
+
+        return self.DEFAULT_BACKGROUND_BATCH_SIZE
+
+    async def _min_batch_size(self, update_name: str, database_name: str) -> int:
+        """A lower bound on the batch size of a new background update.
+
+        Used to ensure that progress is always made. Must be greater than 0.
+        """
+        if self._min_batch_size_callback is not None:
+            return await self._min_batch_size_callback(update_name, database_name)
+
+        return self.MINIMUM_BACKGROUND_BATCH_SIZE
+
     def get_current_update(self) -> Optional[BackgroundUpdatePerformance]:
         """Returns the current background update, if any."""
 
@@ -122,6 +250,8 @@ class BackgroundUpdater:
 
     def start_doing_background_updates(self) -> None:
         if self.enabled:
+            # if we start a new background update, not all updates are done.
+            self._all_done = False
             run_as_background_process("background_updates", self.run_background_updates)
 
     async def run_background_updates(self, sleep: bool = True) -> None:
@@ -133,13 +263,8 @@ class BackgroundUpdater:
         try:
             logger.info("Starting background schema updates")
             while self.enabled:
-                if sleep:
-                    await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
-
                 try:
-                    result = await self.do_next_background_update(
-                        self.BACKGROUND_UPDATE_DURATION_MS
-                    )
+                    result = await self.do_next_background_update(sleep)
                 except Exception:
                     logger.exception("Error doing update")
                 else:
@@ -201,13 +326,15 @@ class BackgroundUpdater:
 
         return not update_exists
 
-    async def do_next_background_update(self, desired_duration_ms: float) -> bool:
+    async def do_next_background_update(self, sleep: bool = True) -> bool:
         """Does some amount of work on the next queued background update
 
         Returns once some amount of work is done.
 
         Args:
-            desired_duration_ms: How long we want to spend updating.
+            sleep: Whether to limit how quickly we run background updates or
+                not.
+
         Returns:
             True if we have finished running all the background updates, otherwise False
         """
@@ -250,7 +377,19 @@ class BackgroundUpdater:
 
             self._current_background_update = upd["update_name"]
 
-        await self._do_background_update(desired_duration_ms)
+        # We have a background update to run, otherwise we would have returned
+        # early.
+        assert self._current_background_update is not None
+        update_info = self._background_update_handlers[self._current_background_update]
+
+        async with self._get_context_manager_for_update(
+            sleep=sleep,
+            update_name=self._current_background_update,
+            database_name=self._database_name,
+            oneshot=update_info.oneshot,
+        ) as desired_duration_ms:
+            await self._do_background_update(desired_duration_ms)
+
         return False
 
     async def _do_background_update(self, desired_duration_ms: float) -> int:
@@ -258,7 +397,7 @@ class BackgroundUpdater:
         update_name = self._current_background_update
         logger.info("Starting update batch on background update '%s'", update_name)
 
-        update_handler = self._background_update_handlers[update_name]
+        update_handler = self._background_update_handlers[update_name].callback
 
         performance = self._background_update_performance.get(update_name)
 
@@ -271,9 +410,14 @@ class BackgroundUpdater:
         if items_per_ms is not None:
             batch_size = int(desired_duration_ms * items_per_ms)
             # Clamp the batch size so that we always make progress
-            batch_size = max(batch_size, self.MINIMUM_BACKGROUND_BATCH_SIZE)
+            batch_size = max(
+                batch_size,
+                await self._min_batch_size(update_name, self._database_name),
+            )
         else:
-            batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE
+            batch_size = await self._default_batch_size(
+                update_name, self._database_name
+            )
 
         progress_json = await self.db_pool.simple_select_one_onecol(
             "background_updates",
@@ -292,6 +436,8 @@ class BackgroundUpdater:
 
         duration_ms = time_stop - time_start
 
+        performance.update(items_updated, duration_ms)
+
         logger.info(
             "Running background update %r. Processed %r items in %rms."
             " (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)",
@@ -304,8 +450,6 @@ class BackgroundUpdater:
             batch_size,
         )
 
-        performance.update(items_updated, duration_ms)
-
         return len(self._background_update_performance)
 
     def register_background_update_handler(
@@ -329,7 +473,9 @@ class BackgroundUpdater:
             update_name: The name of the update that this code handles.
             update_handler: The function that does the update.
         """
-        self._background_update_handlers[update_name] = update_handler
+        self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
+            update_handler
+        )
 
     def register_noop_background_update(self, update_name: str) -> None:
         """Register a noop handler for a background update.
@@ -451,7 +597,9 @@ class BackgroundUpdater:
             await self._end_background_update(update_name)
             return 1
 
-        self.register_background_update_handler(update_name, updater)
+        self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
+            updater, oneshot=True
+        )
 
     async def _end_background_update(self, update_name: str) -> None:
         """Removes a completed background update task from the queue.
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index d4cab69ebf..0693d39006 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -188,7 +188,7 @@ class LoggingDatabaseConnection:
 
 
 # The type of entry which goes on our after_callbacks and exception_callbacks lists.
-_CallbackListEntry = Tuple[Callable[..., None], Iterable[Any], Dict[str, Any]]
+_CallbackListEntry = Tuple[Callable[..., object], Iterable[Any], Dict[str, Any]]
 
 
 R = TypeVar("R")
@@ -235,7 +235,7 @@ class LoggingTransaction:
         self.after_callbacks = after_callbacks
         self.exception_callbacks = exception_callbacks
 
-    def call_after(self, callback: Callable[..., None], *args: Any, **kwargs: Any):
+    def call_after(self, callback: Callable[..., object], *args: Any, **kwargs: Any):
         """Call the given callback on the main twisted thread after the
         transaction has finished. Used to invalidate the caches on the
         correct thread.
@@ -247,7 +247,7 @@ class LoggingTransaction:
         self.after_callbacks.append((callback, args, kwargs))
 
     def call_on_exception(
-        self, callback: Callable[..., None], *args: Any, **kwargs: Any
+        self, callback: Callable[..., object], *args: Any, **kwargs: Any
     ):
         # if self.exception_callbacks is None, that means that whatever constructed the
         # LoggingTransaction isn't expecting there to be any callbacks; assert that
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index b88e6e1a75..68ba330432 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -143,7 +143,7 @@ class ApplicationServiceTransactionWorkerStore(
             A list of ApplicationServices, which may be empty.
         """
         results = await self.db_pool.simple_select_list(
-            "application_services_state", {"state": state}, ["as_id"]
+            "application_services_state", {"state": state.value}, ["as_id"]
         )
         # NB: This assumes this class is linked with ApplicationServiceStore
         as_list = self.get_app_services()
@@ -173,7 +173,7 @@ class ApplicationServiceTransactionWorkerStore(
             desc="get_appservice_state",
         )
         if result:
-            return result.get("state")
+            return ApplicationServiceState(result.get("state"))
         return None
 
     async def set_appservice_state(
@@ -186,7 +186,7 @@ class ApplicationServiceTransactionWorkerStore(
             state: The connectivity state to apply.
         """
         await self.db_pool.simple_upsert(
-            "application_services_state", {"as_id": service.id}, {"state": state}
+            "application_services_state", {"as_id": service.id}, {"state": state.value}
         )
 
     async def create_appservice_txn(
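Illustrative sketch (not part of the patch): the `.value` changes above store the enum's underlying string rather than the enum member itself, assuming ApplicationServiceState is an Enum with string values as the diff implies.

    from enum import Enum

    class ApplicationServiceState(Enum):
        UP = "up"
        DOWN = "down"

    state = ApplicationServiceState.DOWN
    # The database layer needs the plain string, not the Enum member.
    print(state.value)                               # down
    # Reading back: the stored string converts into the member again.
    print(ApplicationServiceState("down") is state)  # True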
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 554c7a549d..d2b285e852 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -673,6 +673,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
     REMOVE_DELETED_DEVICES = "remove_deleted_devices_from_device_inbox"
     REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox"
+    REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox"
 
     def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
@@ -688,14 +689,18 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
             self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
         )
 
-        self.db_pool.updates.register_background_update_handler(
-            self.REMOVE_DELETED_DEVICES,
-            self._remove_deleted_devices_from_device_inbox,
+        # Used to be a background update that deletes all device_inboxes for deleted
+        # devices.
+        self.db_pool.updates.register_noop_background_update(
+            self.REMOVE_DELETED_DEVICES
         )
+        # Used to be a background update that deletes all device_inboxes for hidden
+        # devices.
+        self.db_pool.updates.register_noop_background_update(self.REMOVE_HIDDEN_DEVICES)
 
         self.db_pool.updates.register_background_update_handler(
-            self.REMOVE_HIDDEN_DEVICES,
-            self._remove_hidden_devices_from_device_inbox,
+            self.REMOVE_DEAD_DEVICES_FROM_INBOX,
+            self._remove_dead_devices_from_device_inbox,
         )
 
     async def _background_drop_index_device_inbox(self, progress, batch_size):
@@ -710,171 +715,83 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
 
         return 1
 
-    async def _remove_deleted_devices_from_device_inbox(
-        self, progress: JsonDict, batch_size: int
+    async def _remove_dead_devices_from_device_inbox(
+        self,
+        progress: JsonDict,
+        batch_size: int,
     ) -> int:
-        """A background update that deletes all device_inboxes for deleted devices.
-
-        This should only need to be run once (when users upgrade to v1.47.0)
+        """A background update to remove devices that were either deleted or hidden from
+        the device_inbox table.
 
         Args:
-            progress: JsonDict used to store progress of this background update
-            batch_size: the maximum number of rows to retrieve in a single select query
+            progress: The update's progress dict.
+            batch_size: The batch size for this update.
 
         Returns:
-            The number of deleted rows
+            The number of stream IDs processed in this batch (i.e. the batch size).
         """
 
-        def _remove_deleted_devices_from_device_inbox_txn(
+        def _remove_dead_devices_from_device_inbox_txn(
             txn: LoggingTransaction,
-        ) -> int:
-            """stream_id is not unique
-            we need to use an inclusive `stream_id >= ?` clause,
-            since we might not have deleted all dead device messages for the stream_id
-            returned from the previous query
-
-            Then delete only rows matching the `(user_id, device_id, stream_id)` tuple,
-            to avoid problems of deleting a large number of rows all at once
-            due to a single device having lots of device messages.
-            """
-
-            last_stream_id = progress.get("stream_id", 0)
-
-            sql = """
-                SELECT device_id, user_id, stream_id
-                FROM device_inbox
-                WHERE
-                    stream_id >= ?
-                    AND (device_id, user_id) NOT IN (
-                        SELECT device_id, user_id FROM devices
-                    )
-                ORDER BY stream_id
-                LIMIT ?
-            """
-
-            txn.execute(sql, (last_stream_id, batch_size))
-            rows = txn.fetchall()
-
-            num_deleted = 0
-            for row in rows:
-                num_deleted += self.db_pool.simple_delete_txn(
-                    txn,
-                    "device_inbox",
-                    {"device_id": row[0], "user_id": row[1], "stream_id": row[2]},
-                )
-
-            if rows:
-                # send more than stream_id to progress
-                # otherwise it can happen in large deployments that
-                # no change of status is visible in the log file
-                # it may be that the stream_id does not change in several runs
-                self.db_pool.updates._background_update_progress_txn(
-                    txn,
-                    self.REMOVE_DELETED_DEVICES,
-                    {
-                        "device_id": rows[-1][0],
-                        "user_id": rows[-1][1],
-                        "stream_id": rows[-1][2],
-                    },
-                )
-
-            return num_deleted
-
-        number_deleted = await self.db_pool.runInteraction(
-            "_remove_deleted_devices_from_device_inbox",
-            _remove_deleted_devices_from_device_inbox_txn,
-        )
-
-        # The task is finished when no more lines are deleted.
-        if not number_deleted:
-            await self.db_pool.updates._end_background_update(
-                self.REMOVE_DELETED_DEVICES
-            )
+        ) -> bool:
 
-        return number_deleted
-
-    async def _remove_hidden_devices_from_device_inbox(
-        self, progress: JsonDict, batch_size: int
-    ) -> int:
-        """A background update that deletes all device_inboxes for hidden devices.
-
-        This should only need to be run once (when users upgrade to v1.47.0)
-
-        Args:
-            progress: JsonDict used to store progress of this background update
-            batch_size: the maximum number of rows to retrieve in a single select query
-
-        Returns:
-            The number of deleted rows
-        """
-
-        def _remove_hidden_devices_from_device_inbox_txn(
-            txn: LoggingTransaction,
-        ) -> int:
-            """stream_id is not unique
-            we need to use an inclusive `stream_id >= ?` clause,
-            since we might not have deleted all hidden device messages for the stream_id
-            returned from the previous query
-
-            Then delete only rows matching the `(user_id, device_id, stream_id)` tuple,
-            to avoid problems of deleting a large number of rows all at once
-            due to a single device having lots of device messages.
-            """
+            if "max_stream_id" in progress:
+                max_stream_id = progress["max_stream_id"]
+            else:
+                txn.execute("SELECT max(stream_id) FROM device_inbox")
+                # There's a type mismatch here between how we want to type the row and
+                # what fetchone says it returns, but we silence it because we know that
+                # res can't be None.
+                res: Tuple[Optional[int]] = txn.fetchone()  # type: ignore[assignment]
+                if res[0] is None:
+                    # this can only happen if the `device_inbox` table is empty, in which
+                    # case we have no work to do.
+                    return True
+                else:
+                    max_stream_id = res[0]
 
-            last_stream_id = progress.get("stream_id", 0)
+            start = progress.get("stream_id", 0)
+            stop = start + batch_size
 
+            # delete rows in `device_inbox` which do *not* correspond to a known,
+            # unhidden device.
             sql = """
-                SELECT device_id, user_id, stream_id
-                FROM device_inbox
+                DELETE FROM device_inbox
                 WHERE
-                    stream_id >= ?
-                    AND (device_id, user_id) IN (
-                        SELECT device_id, user_id FROM devices WHERE hidden = ?
+                    stream_id >= ? AND stream_id < ?
+                    AND NOT EXISTS (
+                        SELECT * FROM devices d
+                        WHERE
+                            d.device_id=device_inbox.device_id
+                            AND d.user_id=device_inbox.user_id
+                            AND NOT hidden
                     )
-                ORDER BY stream_id
-                LIMIT ?
-            """
-
-            txn.execute(sql, (last_stream_id, True, batch_size))
-            rows = txn.fetchall()
+                """
 
-            num_deleted = 0
-            for row in rows:
-                num_deleted += self.db_pool.simple_delete_txn(
-                    txn,
-                    "device_inbox",
-                    {"device_id": row[0], "user_id": row[1], "stream_id": row[2]},
-                )
+            txn.execute(sql, (start, stop))
 
-            if rows:
-                # We don't just save the `stream_id` in progress as
-                # otherwise it can happen in large deployments that
-                # no change of status is visible in the log file, as
-                # it may be that the stream_id does not change in several runs
-                self.db_pool.updates._background_update_progress_txn(
-                    txn,
-                    self.REMOVE_HIDDEN_DEVICES,
-                    {
-                        "device_id": rows[-1][0],
-                        "user_id": rows[-1][1],
-                        "stream_id": rows[-1][2],
-                    },
-                )
+            self.db_pool.updates._background_update_progress_txn(
+                txn,
+                self.REMOVE_DEAD_DEVICES_FROM_INBOX,
+                {
+                    "stream_id": stop,
+                    "max_stream_id": max_stream_id,
+                },
+            )
 
-            return num_deleted
+            return stop > max_stream_id
 
-        number_deleted = await self.db_pool.runInteraction(
-            "_remove_hidden_devices_from_device_inbox",
-            _remove_hidden_devices_from_device_inbox_txn,
+        finished = await self.db_pool.runInteraction(
+            "_remove_devices_from_device_inbox_txn",
+            _remove_dead_devices_from_device_inbox_txn,
         )
 
-        # The task is finished when no more lines are deleted.
-        if not number_deleted:
+        if finished:
             await self.db_pool.updates._end_background_update(
-                self.REMOVE_HIDDEN_DEVICES
+                self.REMOVE_DEAD_DEVICES_FROM_INBOX,
             )
 
-        return number_deleted
+        return batch_size
 
 
 class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
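Illustrative sketch (not part of the patch): the rewritten update deletes in fixed stream_id windows and records progress after each window. The windowing pattern, with sqlite3 and a toy table standing in for Synapse's database pool (the real DELETE additionally skips rows belonging to live, unhidden devices):

    import sqlite3
    from typing import Tuple

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE device_inbox (stream_id INTEGER, device_id TEXT)")
    conn.executemany(
        "INSERT INTO device_inbox VALUES (?, ?)",
        [(i, "dead_device") for i in range(10)],
    )

    def run_batch(start: int, batch_size: int, max_stream_id: int) -> Tuple[int, bool]:
        # Delete one window of stream_ids, then report where the next window
        # starts and whether we have passed the maximum stream_id.
        stop = start + batch_size
        conn.execute(
            "DELETE FROM device_inbox WHERE stream_id >= ? AND stream_id < ?",
            (start, stop),
        )
        return stop, stop > max_stream_id

    max_stream_id = conn.execute("SELECT max(stream_id) FROM device_inbox").fetchone()[0]
    start, finished = 0, False
    while not finished:
        start, finished = run_batch(start, 4, max_stream_id)
    print(conn.execute("SELECT count(*) FROM device_inbox").fetchone()[0])  # 0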
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 9ccc66e589..d5a4a661cd 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -139,6 +139,27 @@ class DeviceWorkerStore(SQLBaseStore):
 
         return {d["device_id"]: d for d in devices}
 
+    async def get_devices_by_auth_provider_session_id(
+        self, auth_provider_id: str, auth_provider_session_id: str
+    ) -> List[Dict[str, Any]]:
+        """Retrieve the list of devices associated with a SSO IdP session ID.
+
+        Args:
+            auth_provider_id: The SSO IdP ID as defined in the server config
+            auth_provider_session_id: The session ID within the IdP
+        Returns:
+            A list of dicts containing the device_id and the user_id of each device
+        """
+        return await self.db_pool.simple_select_list(
+            table="device_auth_providers",
+            keyvalues={
+                "auth_provider_id": auth_provider_id,
+                "auth_provider_session_id": auth_provider_session_id,
+            },
+            retcols=("user_id", "device_id"),
+            desc="get_devices_by_auth_provider_session_id",
+        )
+
     @trace
     async def get_device_updates_by_remote(
         self, destination: str, from_stream_id: int, limit: int
@@ -1070,7 +1091,12 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         )
 
     async def store_device(
-        self, user_id: str, device_id: str, initial_device_display_name: Optional[str]
+        self,
+        user_id: str,
+        device_id: str,
+        initial_device_display_name: Optional[str],
+        auth_provider_id: Optional[str] = None,
+        auth_provider_session_id: Optional[str] = None,
     ) -> bool:
         """Ensure the given device is known; add it to the store if not
 
@@ -1079,6 +1105,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             device_id: id of device
             initial_device_display_name: initial displayname of the device.
                 Ignored if device exists.
+            auth_provider_id: The SSO IdP the user used, if any.
+            auth_provider_session_id: The session ID (sid) obtained from an OIDC login.
 
         Returns:
             Whether the device was inserted or an existing device existed with that ID.
@@ -1115,6 +1143,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 if hidden:
                     raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
 
+            if auth_provider_id and auth_provider_session_id:
+                await self.db_pool.simple_insert(
+                    "device_auth_providers",
+                    values={
+                        "user_id": user_id,
+                        "device_id": device_id,
+                        "auth_provider_id": auth_provider_id,
+                        "auth_provider_session_id": auth_provider_session_id,
+                    },
+                    desc="store_device_auth_provider",
+                )
+
             self.device_id_exists_cache.set(key, True)
             return inserted
         except StoreError:
@@ -1168,6 +1208,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 keyvalues={"user_id": user_id},
             )
 
+            self.db_pool.simple_delete_many_txn(
+                txn,
+                table="device_auth_providers",
+                column="device_id",
+                values=device_ids,
+                keyvalues={"user_id": user_id},
+            )
+
         await self.db_pool.runInteraction("delete_devices", _delete_devices_txn)
         for device_id in device_ids:
             self.device_id_exists_cache.invalidate((user_id, device_id))
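Illustrative sketch (not part of the patch): the lookup performed by get_devices_by_auth_provider_session_id against the new device_auth_providers table, with sqlite3 standing in for the database pool and invented example values.

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE device_auth_providers ("
        " user_id TEXT, device_id TEXT,"
        " auth_provider_id TEXT, auth_provider_session_id TEXT)"
    )
    conn.execute(
        "INSERT INTO device_auth_providers VALUES (?, ?, ?, ?)",
        ("@alice:example.org", "DEVICEID1", "oidc", "session-abc"),
    )

    # Find every device that was created through a given IdP session.
    rows = conn.execute(
        "SELECT user_id, device_id FROM device_auth_providers"
        " WHERE auth_provider_id = ? AND auth_provider_session_id = ?",
        ("oidc", "session-abc"),
    ).fetchall()
    print(rows)  # [('@alice:example.org', 'DEVICEID1')]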
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index a95ac34f09..b06c1dc45b 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -408,29 +408,58 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
             fallback_keys: the keys to set.  This is a map from key ID (which is
                 of the form "algorithm:id") to key data.
         """
+        await self.db_pool.runInteraction(
+            "set_e2e_fallback_keys_txn",
+            self._set_e2e_fallback_keys_txn,
+            user_id,
+            device_id,
+            fallback_keys,
+        )
+
+        await self.invalidate_cache_and_stream(
+            "get_e2e_unused_fallback_key_types", (user_id, device_id)
+        )
+
+    def _set_e2e_fallback_keys_txn(
+        self, txn: Connection, user_id: str, device_id: str, fallback_keys: JsonDict
+    ) -> None:
         # fallback_keys will usually only have one item in it, so using a for
         # loop (as opposed to calling simple_upsert_many_txn) won't be too bad
         # FIXME: make sure that only one key per algorithm is uploaded
         for key_id, fallback_key in fallback_keys.items():
             algorithm, key_id = key_id.split(":", 1)
-            await self.db_pool.simple_upsert(
-                "e2e_fallback_keys_json",
+            old_key_json = self.db_pool.simple_select_one_onecol_txn(
+                txn,
+                table="e2e_fallback_keys_json",
                 keyvalues={
                     "user_id": user_id,
                     "device_id": device_id,
                     "algorithm": algorithm,
                 },
-                values={
-                    "key_id": key_id,
-                    "key_json": json_encoder.encode(fallback_key),
-                    "used": False,
-                },
-                desc="set_e2e_fallback_key",
+                retcol="key_json",
+                allow_none=True,
             )
 
-        await self.invalidate_cache_and_stream(
-            "get_e2e_unused_fallback_key_types", (user_id, device_id)
-        )
+            new_key_json = encode_canonical_json(fallback_key).decode("utf-8")
+
+            # If the uploaded key is the same as the current fallback key,
+            # don't do anything.  This prevents marking the key as unused if it
+            # was already used.
+            if old_key_json != new_key_json:
+                self.db_pool.simple_upsert_txn(
+                    txn,
+                    table="e2e_fallback_keys_json",
+                    keyvalues={
+                        "user_id": user_id,
+                        "device_id": device_id,
+                        "algorithm": algorithm,
+                    },
+                    values={
+                        "key_id": key_id,
+                        "key_json": json_encoder.encode(fallback_key),
+                        "used": False,
+                    },
+                )
 
     @cached(max_entries=10000)
     async def get_e2e_unused_fallback_key_types(
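Illustrative sketch (not part of the patch): the compare-before-upsert idea introduced above, writing (and resetting the used flag) only when the uploaded fallback key differs from the stored one. Canonical JSON is approximated with json.dumps(..., sort_keys=True) and a plain dict stands in for the table.

    import json
    from typing import Dict, Tuple

    # (user_id, device_id, algorithm) -> {"key_json": ..., "used": ...}
    store: Dict[Tuple[str, str, str], dict] = {}

    def set_fallback_key(user_id: str, device_id: str, key_id: str, key: dict) -> None:
        algorithm, _ = key_id.split(":", 1)
        new_key_json = json.dumps(key, sort_keys=True)
        row = store.get((user_id, device_id, algorithm))
        # Skip the write when the key is unchanged, so re-uploading the same
        # fallback key does not mark it as unused again.
        if row is None or row["key_json"] != new_key_json:
            store[(user_id, device_id, algorithm)] = {
                "key_json": new_key_json,
                "used": False,
            }

    set_fallback_key("@alice:example.org", "DEV", "signed_curve25519:AAAA", {"key": "abc"})
    store[("@alice:example.org", "DEV", "signed_curve25519")]["used"] = True
    set_fallback_key("@alice:example.org", "DEV", "signed_curve25519:AAAA", {"key": "abc"})
    print(store[("@alice:example.org", "DEV", "signed_curve25519")]["used"])  # True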
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index ef5d1ef01e..9580a40785 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1552,9 +1552,9 @@ class EventFederationStore(EventFederationWorkerStore):
                 DELETE FROM event_auth
                 WHERE event_id IN (
                     SELECT event_id FROM events
-                    LEFT JOIN state_events USING (room_id, event_id)
+                    LEFT JOIN state_events AS se USING (room_id, event_id)
                     WHERE ? <= stream_ordering AND stream_ordering < ?
-                        AND state_key IS null
+                        AND se.state_key IS null
                 )
             """
 
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index d957e770dc..3efdd0c920 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -16,6 +16,7 @@ import logging
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import attr
+from typing_extensions import TypedDict
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json
@@ -37,6 +38,20 @@ DEFAULT_HIGHLIGHT_ACTION = [
 ]
 
 
+class BasePushAction(TypedDict):
+    event_id: str
+    actions: List[Union[dict, str]]
+
+
+class HttpPushAction(BasePushAction):
+    room_id: str
+    stream_ordering: int
+
+
+class EmailPushAction(HttpPushAction):
+    received_ts: Optional[int]
+
+
 def _serialize_action(actions, is_highlight):
     """Custom serializer for actions. This allows us to "compress" common actions.
 
@@ -221,7 +236,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         min_stream_ordering: int,
         max_stream_ordering: int,
         limit: int = 20,
-    ) -> List[dict]:
+    ) -> List[HttpPushAction]:
         """Get a list of the most recent unread push actions for a given user,
         within the given stream ordering range. Called by the httppusher.
 
@@ -326,7 +341,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         min_stream_ordering: int,
         max_stream_ordering: int,
         limit: int = 20,
-    ) -> List[dict]:
+    ) -> List[EmailPushAction]:
         """Get a list of the most recent unread push actions for a given user,
         within the given stream ordering range. Called by the emailpusher
 
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 120e4807d1..4e528612ea 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1,6 +1,6 @@
 # Copyright 2014-2016 OpenMarket Ltd
 # Copyright 2018-2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,7 +15,7 @@
 # limitations under the License.
 import itertools
 import logging
-from collections import OrderedDict, namedtuple
+from collections import OrderedDict
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -41,9 +41,10 @@ from synapse.events.snapshot import EventContext  # noqa: F401
 from synapse.logging.utils import log_function
 from synapse.storage._base import db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.databases.main.events_worker import EventCacheEntry
 from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.storage.util.id_generators import AbstractStreamIdGenerator
 from synapse.storage.util.sequence import SequenceGenerator
 from synapse.types import StateMap, get_domain_from_id
 from synapse.util import json_encoder
@@ -64,9 +65,6 @@ event_counter = Counter(
 )
 
 
-_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
-
-
 @attr.s(slots=True)
 class DeltaState:
     """Deltas to use to update the `current_state_events` table.
@@ -108,23 +106,30 @@ class PersistEventsStore:
         self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
         self.is_mine_id = hs.is_mine_id
 
-        # Ideally we'd move these ID gens here, unfortunately some other ID
-        # generators are chained off them so doing so is a bit of a PITA.
-        self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen
-        self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen
-
         # This should only exist on instances that are configured to write
         assert (
             hs.get_instance_name() in hs.config.worker.writers.events
         ), "Can only instantiate EventsStore on master"
 
+        # Since we have been configured to write, we ought to have id generators,
+        # rather than id trackers.
+        assert isinstance(self.store._backfill_id_gen, AbstractStreamIdGenerator)
+        assert isinstance(self.store._stream_id_gen, AbstractStreamIdGenerator)
+
+        # Ideally we'd move these ID gens here, unfortunately some other ID
+        # generators are chained off them so doing so is a bit of a PITA.
+        self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
+        self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
+
     async def _persist_events_and_state_updates(
         self,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
+        *,
         current_state_for_room: Dict[str, StateMap[str]],
         state_delta_for_room: Dict[str, DeltaState],
         new_forward_extremeties: Dict[str, List[str]],
-        backfilled: bool = False,
+        use_negative_stream_ordering: bool = False,
+        inhibit_local_membership_updates: bool = False,
     ) -> None:
         """Persist a set of events alongside updates to the current state and
         forward extremities tables.
@@ -137,7 +142,14 @@ class PersistEventsStore:
                 room state
             new_forward_extremities: Map from room_id to list of event IDs
                 that are the new forward extremities of the room.
-            backfilled
+            use_negative_stream_ordering: Whether to start stream_ordering on
+                the negative side and decrement. This should be set to True
+                for backfilled events because backfilled events get a negative
+                stream ordering so they don't come down incremental `/sync`.
+            inhibit_local_membership_updates: Stop the local_current_membership
+                from being updated by these events. This should be set to True
+                for backfilled events because backfilled events in the past do
+                not affect the current local state.
 
         Returns:
             Resolves when the events have been persisted
@@ -159,7 +171,7 @@ class PersistEventsStore:
         #
         # Note: Multiple instances of this function cannot be in flight at
         # the same time for the same room.
-        if backfilled:
+        if use_negative_stream_ordering:
             stream_ordering_manager = self._backfill_id_gen.get_next_mult(
                 len(events_and_contexts)
             )
@@ -176,13 +188,13 @@ class PersistEventsStore:
                 "persist_events",
                 self._persist_events_txn,
                 events_and_contexts=events_and_contexts,
-                backfilled=backfilled,
+                inhibit_local_membership_updates=inhibit_local_membership_updates,
                 state_delta_for_room=state_delta_for_room,
                 new_forward_extremeties=new_forward_extremeties,
             )
             persist_event_counter.inc(len(events_and_contexts))
 
-            if not backfilled:
+            if stream >= 0:
                 # backfilled events have negative stream orderings, so we don't
                 # want to set the event_persisted_position to that.
                 synapse.metrics.event_persisted_position.set(
@@ -316,8 +328,9 @@ class PersistEventsStore:
     def _persist_events_txn(
         self,
         txn: LoggingTransaction,
+        *,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
-        backfilled: bool,
+        inhibit_local_membership_updates: bool = False,
         state_delta_for_room: Optional[Dict[str, DeltaState]] = None,
         new_forward_extremeties: Optional[Dict[str, List[str]]] = None,
     ):
@@ -330,7 +343,10 @@ class PersistEventsStore:
         Args:
             txn
             events_and_contexts: events to persist
-            backfilled: True if the events were backfilled
+            inhibit_local_membership_updates: Stop the local_current_membership
+                from being updated by these events. This should be set to True
+                for backfilled events because backfilled events in the past do
+                not affect the current local state.
             delete_existing True to purge existing table rows for the events
                 from the database. This is useful when retrying due to
                 IntegrityError.
@@ -363,9 +379,7 @@ class PersistEventsStore:
             events_and_contexts
         )
 
-        self._update_room_depths_txn(
-            txn, events_and_contexts=events_and_contexts, backfilled=backfilled
-        )
+        self._update_room_depths_txn(txn, events_and_contexts=events_and_contexts)
 
         # _update_outliers_txn filters out any events which have already been
         # persisted, and returns the filtered list.
@@ -398,7 +412,7 @@ class PersistEventsStore:
             txn,
             events_and_contexts=events_and_contexts,
             all_events_and_contexts=all_events_and_contexts,
-            backfilled=backfilled,
+            inhibit_local_membership_updates=inhibit_local_membership_updates,
         )
 
         # We call this last as it assumes we've inserted the events into
@@ -561,9 +575,9 @@ class PersistEventsStore:
         # fetch their auth event info.
         while missing_auth_chains:
             sql = """
-                SELECT event_id, events.type, state_key, chain_id, sequence_number
+                SELECT event_id, events.type, se.state_key, chain_id, sequence_number
                 FROM events
-                INNER JOIN state_events USING (event_id)
+                INNER JOIN state_events AS se USING (event_id)
                 LEFT JOIN event_auth_chains USING (event_id)
                 WHERE
             """
@@ -1200,7 +1214,6 @@ class PersistEventsStore:
         self,
         txn,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
-        backfilled: bool,
     ):
         """Update min_depth for each room
 
@@ -1208,13 +1221,18 @@ class PersistEventsStore:
             txn (twisted.enterprise.adbapi.Connection): db connection
             events_and_contexts (list[(EventBase, EventContext)]): events
                 we are persisting
-            backfilled (bool): True if the events were backfilled
         """
         depth_updates: Dict[str, int] = {}
         for event, context in events_and_contexts:
             # Remove any existing cache entries for the event_ids
             txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
-            if not backfilled:
+            # Then update the `stream_ordering` position to mark the latest
+            # event as the front of the room. This should not be done for
+            # backfilled events because backfilled events have negative
+            # stream_ordering and happened in the past so we know that we don't
+            # need to update the stream_ordering tip/front for the room.
+            assert event.internal_metadata.stream_ordering is not None
+            if event.internal_metadata.stream_ordering >= 0:
                 txn.call_after(
                     self.store._events_stream_cache.entity_has_changed,
                     event.room_id,
@@ -1427,7 +1445,12 @@ class PersistEventsStore:
         return [ec for ec in events_and_contexts if ec[0] not in to_remove]
 
     def _update_metadata_tables_txn(
-        self, txn, events_and_contexts, all_events_and_contexts, backfilled
+        self,
+        txn,
+        *,
+        events_and_contexts,
+        all_events_and_contexts,
+        inhibit_local_membership_updates: bool = False,
     ):
         """Update all the miscellaneous tables for new events
 
@@ -1439,7 +1462,10 @@ class PersistEventsStore:
                 events that we were going to persist. This includes events
                 we've already persisted, etc, that wouldn't appear in
                 events_and_context.
-            backfilled (bool): True if the events were backfilled
+            inhibit_local_membership_updates: Stop the local_current_membership
+                from being updated by these events. This should be set to True
+                for backfilled events because backfilled events in the past do
+                not affect the current local state.
         """
 
         # Insert all the push actions into the event_push_actions table.
@@ -1513,7 +1539,7 @@ class PersistEventsStore:
                 for event, _ in events_and_contexts
                 if event.type == EventTypes.Member
             ],
-            backfilled=backfilled,
+            inhibit_local_membership_updates=inhibit_local_membership_updates,
         )
 
         # Insert event_reference_hashes table.
@@ -1553,11 +1579,13 @@ class PersistEventsStore:
         for row in rows:
             event = ev_map[row["event_id"]]
             if not row["rejects"] and not row["redacts"]:
-                to_prefill.append(_EventCacheEntry(event=event, redacted_event=None))
+                to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
 
         def prefill():
             for cache_entry in to_prefill:
-                self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry)
+                self.store._get_event_cache.set(
+                    (cache_entry.event.event_id,), cache_entry
+                )
 
         txn.call_after(prefill)
 
@@ -1638,8 +1666,19 @@ class PersistEventsStore:
             txn, table="event_reference_hashes", values=vals
         )
 
-    def _store_room_members_txn(self, txn, events, backfilled):
-        """Store a room member in the database."""
+    def _store_room_members_txn(
+        self, txn, events, *, inhibit_local_membership_updates: bool = False
+    ):
+        """
+        Store room membership events in the database.
+
+        Args:
+            txn: The transaction to use.
+            events: List of events to store.
+            inhibit_local_membership_updates: Stop the local_current_membership
+                from being updated by these events. This should be set to True
+                for backfilled events because backfilled events in the past do
+                not affect the current local state.
+        """
 
         def non_null_str_or_none(val: Any) -> Optional[str]:
             return val if isinstance(val, str) and "\u0000" not in val else None
@@ -1682,7 +1721,7 @@ class PersistEventsStore:
             # band membership", like a remote invite or a rejection of a remote invite.
             if (
                 self.is_mine_id(event.state_key)
-                and not backfilled
+                and not inhibit_local_membership_updates
                 and event.internal_metadata.is_outlier()
                 and event.internal_metadata.is_out_of_band_membership()
             ):
@@ -1696,34 +1735,33 @@ class PersistEventsStore:
                     },
                 )
 
-    def _handle_event_relations(self, txn, event):
-        """Handles inserting relation data during peristence of events
+    def _handle_event_relations(
+        self, txn: LoggingTransaction, event: EventBase
+    ) -> None:
+        """Handles inserting relation data during persistence of events
 
         Args:
-            txn
-            event (EventBase)
+            txn: The current database transaction.
+            event: The event which might have relations.
         """
         relation = event.content.get("m.relates_to")
         if not relation:
             # No relations
             return
 
+        # Relations must have a type and parent event ID.
         rel_type = relation.get("rel_type")
-        if rel_type not in (
-            RelationTypes.ANNOTATION,
-            RelationTypes.REFERENCE,
-            RelationTypes.REPLACE,
-            RelationTypes.THREAD,
-        ):
-            # Unknown relation type
+        if not isinstance(rel_type, str):
             return
 
         parent_id = relation.get("event_id")
-        if not parent_id:
-            # Invalid relation
+        if not isinstance(parent_id, str):
             return
 
-        aggregation_key = relation.get("key")
+        # Annotations have a key field.
+        aggregation_key = None
+        if rel_type == RelationTypes.ANNOTATION:
+            aggregation_key = relation.get("key")
 
         self.db_pool.simple_insert_txn(
             txn,
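A hedged sketch of the validation logic introduced above: any string rel_type is now accepted, a string parent event ID is still required, and an aggregation key is only kept for annotations (RelationTypes.ANNOTATION is "m.annotation"). The helper below is illustrative, not the store method itself.

from typing import Optional, Tuple

ANNOTATION = "m.annotation"

def parse_relation(content: dict) -> Optional[Tuple[str, str, Optional[str]]]:
    relation = content.get("m.relates_to")
    if not isinstance(relation, dict):
        return None
    rel_type = relation.get("rel_type")
    parent_id = relation.get("event_id")
    if not isinstance(rel_type, str) or not isinstance(parent_id, str):
        return None
    aggregation_key = relation.get("key") if rel_type == ANNOTATION else None
    return rel_type, parent_id, aggregation_key

# A reaction-style annotation keeps its key; an unknown rel_type is still stored.
assert parse_relation(
    {"m.relates_to": {"rel_type": "m.annotation", "event_id": "$parent", "key": "👍"}}
) == ("m.annotation", "$parent", "👍")
assert parse_relation(
    {"m.relates_to": {"rel_type": "org.example.custom", "event_id": "$parent"}}
) == ("org.example.custom", "$parent", None)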
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index ae3a8a63e4..c88fd35e7f 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -1,4 +1,4 @@
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -171,8 +171,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             self._purged_chain_cover_index,
         )
 
+        # The event_thread_relation background update was replaced with the
+        # event_arbitrary_relations one, which handles any relation to avoid
+        # needing to potentially crawl the entire events table in the future.
+        self.db_pool.updates.register_noop_background_update("event_thread_relation")
+
         self.db_pool.updates.register_background_update_handler(
-            "event_thread_relation", self._event_thread_relation
+            "event_arbitrary_relations",
+            self._event_arbitrary_relations,
         )
 
         ################################################################################
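A minimal sketch (not Synapse's real background-update machinery) of why the no-op registration is needed: a database that still has "event_thread_relation" queued must be able to mark it as complete, even though the replacement update now does the actual work.

from typing import Awaitable, Callable, Dict

Handler = Callable[[dict, int], Awaitable[int]]

class Updates:
    def __init__(self) -> None:
        self.handlers: Dict[str, Handler] = {}

    def register_background_update_handler(self, name: str, handler: Handler) -> None:
        self.handlers[name] = handler

    def register_noop_background_update(self, name: str) -> None:
        async def noop(progress: dict, batch_size: int) -> int:
            # Nothing left to do; the queued update completes immediately.
            return 0

        self.handlers[name] = noop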
@@ -1099,23 +1105,27 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
         return result
 
-    async def _event_thread_relation(self, progress: JsonDict, batch_size: int) -> int:
-        """Background update handler which will store thread relations for existing events."""
+    async def _event_arbitrary_relations(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        """Background update handler which will store previously unknown relations for existing events."""
         last_event_id = progress.get("last_event_id", "")
 
-        def _event_thread_relation_txn(txn: LoggingTransaction) -> int:
+        def _event_arbitrary_relations_txn(txn: LoggingTransaction) -> int:
+            # Fetch events and then filter based on whether the event has a
+            # relation or not.
             txn.execute(
                 """
                 SELECT event_id, json FROM event_json
-                LEFT JOIN event_relations USING (event_id)
-                WHERE event_id > ? AND event_relations.event_id IS NULL
+                WHERE event_id > ?
                 ORDER BY event_id LIMIT ?
                 """,
                 (last_event_id, batch_size),
             )
 
             results = list(txn)
-            missing_thread_relations = []
+            # (event_id, parent_id, rel_type) for each relation
+            relations_to_insert: List[Tuple[str, str, str]] = []
             for (event_id, event_json_raw) in results:
                 try:
                     event_json = db_to_json(event_json_raw)
@@ -1127,48 +1137,70 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
                     )
                     continue
 
-                # If there's no relation (or it is not a thread), skip!
+                # If there's no relation, skip!
                 relates_to = event_json["content"].get("m.relates_to")
                 if not relates_to or not isinstance(relates_to, dict):
                     continue
-                if relates_to.get("rel_type") != RelationTypes.THREAD:
+
+                # If the relation type or parent event ID is not a string, skip it.
+                #
+                # Do not consider relation types that have existed for a long time,
+                # since they will already be listed in the `event_relations` table.
+                rel_type = relates_to.get("rel_type")
+                if not isinstance(rel_type, str) or rel_type in (
+                    RelationTypes.ANNOTATION,
+                    RelationTypes.REFERENCE,
+                    RelationTypes.REPLACE,
+                ):
                     continue
 
-                # Get the parent ID.
                 parent_id = relates_to.get("event_id")
                 if not isinstance(parent_id, str):
                     continue
 
-                missing_thread_relations.append((event_id, parent_id))
+                relations_to_insert.append((event_id, parent_id, rel_type))
+
+            # Insert the missing data, note that we upsert here in case the event
+            # has already been processed.
+            if relations_to_insert:
+                self.db_pool.simple_upsert_many_txn(
+                    txn=txn,
+                    table="event_relations",
+                    key_names=("event_id",),
+                    key_values=[(r[0],) for r in relations_to_insert],
+                    value_names=("relates_to_id", "relation_type"),
+                    value_values=[r[1:] for r in relations_to_insert],
+                )
 
-            # Insert the missing data.
-            self.db_pool.simple_insert_many_txn(
-                txn=txn,
-                table="event_relations",
-                values=[
-                    {
-                        "event_id": event_id,
-                        "relates_to_Id": parent_id,
-                        "relation_type": RelationTypes.THREAD,
-                    }
-                    for event_id, parent_id in missing_thread_relations
-                ],
-            )
+                # Iterate the parent IDs and invalidate caches.
+                for parent_id in {r[1] for r in relations_to_insert}:
+                    cache_tuple = (parent_id,)
+                    self._invalidate_cache_and_stream(
+                        txn, self.get_relations_for_event, cache_tuple
+                    )
+                    self._invalidate_cache_and_stream(
+                        txn, self.get_aggregation_groups_for_event, cache_tuple
+                    )
+                    self._invalidate_cache_and_stream(
+                        txn, self.get_thread_summary, cache_tuple
+                    )
 
             if results:
                 latest_event_id = results[-1][0]
                 self.db_pool.updates._background_update_progress_txn(
-                    txn, "event_thread_relation", {"last_event_id": latest_event_id}
+                    txn, "event_arbitrary_relations", {"last_event_id": latest_event_id}
                 )
 
             return len(results)
 
         num_rows = await self.db_pool.runInteraction(
-            desc="event_thread_relation", func=_event_thread_relation_txn
+            desc="event_arbitrary_relations", func=_event_arbitrary_relations_txn
         )
 
         if not num_rows:
-            await self.db_pool.updates._end_background_update("event_thread_relation")
+            await self.db_pool.updates._end_background_update(
+                "event_arbitrary_relations"
+            )
 
         return num_rows
 
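A small sketch of how the (event_id, parent_id, rel_type) tuples above are split into key and value columns for the upsert, and how the distinct parents are collected for cache invalidation; the column names mirror the diff, but no real DatabasePool helper is called here.

from typing import List, Tuple

relations_to_insert: List[Tuple[str, str, str]] = [
    ("$child1", "$parent", "m.thread"),
    ("$child2", "$parent", "org.example.custom"),
]

key_values = [(r[0],) for r in relations_to_insert]   # one key tuple per event_id
value_values = [r[1:] for r in relations_to_insert]   # (relates_to_id, relation_type)

assert key_values == [("$child1",), ("$child2",)]
assert value_values == [("$parent", "m.thread"), ("$parent", "org.example.custom")]

# Each distinct parent then has its relation caches invalidated exactly once.
parents_to_invalidate = {r[1] for r in relations_to_insert}
assert parents_to_invalidate == {"$parent"}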
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index c6bf316d5b..c7b660ac5a 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -15,14 +15,18 @@
 import logging
 import threading
 from typing import (
+    TYPE_CHECKING,
+    Any,
     Collection,
     Container,
     Dict,
     Iterable,
     List,
+    NoReturn,
     Optional,
     Set,
     Tuple,
+    cast,
     overload,
 )
 
@@ -38,6 +42,7 @@ from synapse.api.errors import NotFoundError, SynapseError
 from synapse.api.room_versions import (
     KNOWN_ROOM_VERSIONS,
     EventFormatVersions,
+    RoomVersion,
     RoomVersions,
 )
 from synapse.events import EventBase, make_event_from_dict
@@ -56,10 +61,18 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import BackfillStream
 from synapse.replication.tcp.streams.events import EventsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import (
+    AbstractStreamIdTracker,
+    MultiWriterIdGenerator,
+    StreamIdGenerator,
+)
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import JsonDict, get_domain_from_id
 from synapse.util import unwrapFirstError
@@ -69,10 +82,13 @@ from synapse.util.caches.lrucache import LruCache
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
-# These values are used in the `enqueus_event` and `_do_fetch` methods to
+# These values are used in the `_enqueue_events` and `_fetch_loop` methods to
 # control how we batch/bulk fetch events from the database.
 # The values are plucked out of thin air to make initial sync run faster
 # on jki.re
@@ -89,7 +105,7 @@ event_fetch_ongoing_gauge = Gauge(
 
 
 @attr.s(slots=True, auto_attribs=True)
-class _EventCacheEntry:
+class EventCacheEntry:
     event: EventBase
     redacted_event: Optional[EventBase]
 
@@ -129,7 +145,7 @@ class _EventRow:
     json: str
     internal_metadata: str
     format_version: Optional[int]
-    room_version_id: Optional[int]
+    room_version_id: Optional[str]
     rejected_reason: Optional[str]
     redactions: List[str]
     outlier: bool
@@ -153,9 +169,16 @@ class EventsWorkerStore(SQLBaseStore):
     # options controlling this.
     USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
+        self._stream_id_gen: AbstractStreamIdTracker
+        self._backfill_id_gen: AbstractStreamIdTracker
         if isinstance(database.engine, PostgresEngine):
             # If we're using Postgres than we can use `MultiWriterIdGenerator`
             # regardless of whether this process writes to the streams or not.
@@ -214,7 +237,7 @@ class EventsWorkerStore(SQLBaseStore):
                 5 * 60 * 1000,
             )
 
-        self._get_event_cache = LruCache(
+        self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache(
             cache_name="*getEvent*",
             max_size=hs.config.caches.event_cache_size,
         )
@@ -223,19 +246,21 @@ class EventsWorkerStore(SQLBaseStore):
         # ID to cache entry. Note that the returned dict may not have the
         # requested event in it if the event isn't in the DB.
         self._current_event_fetches: Dict[
-            str, ObservableDeferred[Dict[str, _EventCacheEntry]]
+            str, ObservableDeferred[Dict[str, EventCacheEntry]]
         ] = {}
 
         self._event_fetch_lock = threading.Condition()
-        self._event_fetch_list = []
+        self._event_fetch_list: List[
+            Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]
+        ] = []
         self._event_fetch_ongoing = 0
         event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
 
         # We define this sequence here so that it can be referenced from both
         # the DataStore and PersistEventStore.
-        def get_chain_id_txn(txn):
+        def get_chain_id_txn(txn: Cursor) -> int:
             txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
-            return txn.fetchone()[0]
+            return cast(Tuple[int], txn.fetchone())[0]
 
         self.event_chain_id_gen = build_sequence_generator(
             db_conn,
@@ -246,7 +271,13 @@ class EventsWorkerStore(SQLBaseStore):
             id_column="chain_id",
         )
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self,
+        stream_name: str,
+        instance_name: str,
+        token: int,
+        rows: Iterable[Any],
+    ) -> None:
         if stream_name == EventsStream.NAME:
             self._stream_id_gen.advance(instance_name, token)
         elif stream_name == BackfillStream.NAME:
@@ -280,10 +311,10 @@ class EventsWorkerStore(SQLBaseStore):
         self,
         event_id: str,
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
-        get_prev_content: bool = False,
-        allow_rejected: bool = False,
-        allow_none: Literal[False] = False,
-        check_room_id: Optional[str] = None,
+        get_prev_content: bool = ...,
+        allow_rejected: bool = ...,
+        allow_none: Literal[False] = ...,
+        check_room_id: Optional[str] = ...,
     ) -> EventBase:
         ...
 
@@ -292,10 +323,10 @@ class EventsWorkerStore(SQLBaseStore):
         self,
         event_id: str,
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
-        get_prev_content: bool = False,
-        allow_rejected: bool = False,
-        allow_none: Literal[True] = False,
-        check_room_id: Optional[str] = None,
+        get_prev_content: bool = ...,
+        allow_rejected: bool = ...,
+        allow_none: Literal[True] = ...,
+        check_room_id: Optional[str] = ...,
     ) -> Optional[EventBase]:
         ...
 
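A sketch of the typing pattern used in these overloads: the stub signatures carry `...` as the default because only the implementation holds the real defaults, while the Literal[False]/Literal[True] split lets mypy narrow the return type by the allow_none argument. The function below is a toy stand-in, not get_event itself.

from typing import Literal, Optional, overload

@overload
def get_thing(key: str, allow_none: Literal[False] = ...) -> str: ...
@overload
def get_thing(key: str, allow_none: Literal[True] = ...) -> Optional[str]: ...

def get_thing(key: str, allow_none: bool = False) -> Optional[str]:
    if key == "missing":
        if allow_none:
            return None
        raise KeyError(key)
    return key.upper()

value: str = get_thing("abc")                      # narrowed to str
maybe: Optional[str] = get_thing("missing", True)  # narrowed to Optional[str]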
@@ -357,7 +388,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def get_events(
         self,
-        event_ids: Iterable[str],
+        event_ids: Collection[str],
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
         get_prev_content: bool = False,
         allow_rejected: bool = False,
@@ -544,7 +575,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def _get_events_from_cache_or_db(
         self, event_ids: Iterable[str], allow_rejected: bool = False
-    ) -> Dict[str, _EventCacheEntry]:
+    ) -> Dict[str, EventCacheEntry]:
         """Fetch a bunch of events from the cache or the database.
 
         If events are pulled from the database, they will be cached for future lookups.
@@ -578,7 +609,7 @@ class EventsWorkerStore(SQLBaseStore):
         # same dict into itself N times).
         already_fetching_ids: Set[str] = set()
         already_fetching_deferreds: Set[
-            ObservableDeferred[Dict[str, _EventCacheEntry]]
+            ObservableDeferred[Dict[str, EventCacheEntry]]
         ] = set()
 
         for event_id in missing_events_ids:
@@ -601,8 +632,8 @@ class EventsWorkerStore(SQLBaseStore):
             # function returning more events than requested, but that can happen
             # already due to `_get_events_from_db`).
             fetching_deferred: ObservableDeferred[
-                Dict[str, _EventCacheEntry]
-            ] = ObservableDeferred(defer.Deferred())
+                Dict[str, EventCacheEntry]
+            ] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
             for event_id in missing_events_ids:
                 self._current_event_fetches[event_id] = fetching_deferred
 
@@ -658,12 +689,12 @@ class EventsWorkerStore(SQLBaseStore):
 
         return event_entry_map
 
-    def _invalidate_get_event_cache(self, event_id):
+    def _invalidate_get_event_cache(self, event_id: str) -> None:
         self._get_event_cache.invalidate((event_id,))
 
     def _get_events_from_cache(
         self, events: Iterable[str], update_metrics: bool = True
-    ) -> Dict[str, _EventCacheEntry]:
+    ) -> Dict[str, EventCacheEntry]:
         """Fetch events from the caches.
 
         May return rejected events.
@@ -736,38 +767,123 @@ class EventsWorkerStore(SQLBaseStore):
             for e in state_to_include.values()
         ]
 
-    def _do_fetch(self, conn: Connection) -> None:
+    def _maybe_start_fetch_thread(self) -> None:
+        """Starts an event fetch thread if we are not yet at the maximum number."""
+        with self._event_fetch_lock:
+            if (
+                self._event_fetch_list
+                and self._event_fetch_ongoing < EVENT_QUEUE_THREADS
+            ):
+                self._event_fetch_ongoing += 1
+                event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+                # `_event_fetch_ongoing` is decremented in `_fetch_thread`.
+                should_start = True
+            else:
+                should_start = False
+
+        if should_start:
+            run_as_background_process("fetch_events", self._fetch_thread)
+
+    async def _fetch_thread(self) -> None:
+        """Services requests for events from `_event_fetch_list`."""
+        exc = None
+        try:
+            await self.db_pool.runWithConnection(self._fetch_loop)
+        except BaseException as e:
+            exc = e
+            raise
+        finally:
+            should_restart = False
+            event_fetches_to_fail = []
+            with self._event_fetch_lock:
+                self._event_fetch_ongoing -= 1
+                event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+
+                # There may still be work remaining in `_event_fetch_list` if we
+                # failed, or it was added in between us deciding to exit and
+                # decrementing `_event_fetch_ongoing`.
+                if self._event_fetch_list:
+                    if exc is None:
+                        # We decided to exit, but then some more work was added
+                        # before `_event_fetch_ongoing` was decremented.
+                        # If a new event fetch thread was not started, we should
+                        # restart ourselves since the remaining event fetch threads
+                        # may take a while to get around to the new work.
+                        #
+                        # Unfortunately it is not possible to tell whether a new
+                        # event fetch thread was started, so we restart
+                        # unconditionally. If we are unlucky, we will end up with
+                        # an idle fetch thread, but it will time out after
+                        # `EVENT_QUEUE_ITERATIONS * EVENT_QUEUE_TIMEOUT_S` seconds
+                        # in any case.
+                        #
+                        # Note that multiple fetch threads may run down this path at
+                        # the same time.
+                        should_restart = True
+                    elif isinstance(exc, Exception):
+                        if self._event_fetch_ongoing == 0:
+                            # We were the last remaining fetcher and failed.
+                            # Fail any outstanding fetches since no one else will
+                            # handle them.
+                            event_fetches_to_fail = self._event_fetch_list
+                            self._event_fetch_list = []
+                        else:
+                            # We weren't the last remaining fetcher, so another
+                            # fetcher will pick up the work. This will either happen
+                            # after their existing work, however long that takes,
+                            # or after at most `EVENT_QUEUE_TIMEOUT_S` seconds if
+                            # they are idle.
+                            pass
+                    else:
+                        # The exception is a `SystemExit`, `KeyboardInterrupt` or
+                        # `GeneratorExit`. Don't try to do anything clever here.
+                        pass
+
+            if should_restart:
+                # We exited cleanly but noticed more work.
+                self._maybe_start_fetch_thread()
+
+            if event_fetches_to_fail:
+                # We were the last remaining fetcher and failed.
+                # Fail any outstanding fetches since no one else will handle them.
+                assert exc is not None
+                with PreserveLoggingContext():
+                    for _, deferred in event_fetches_to_fail:
+                        deferred.errback(exc)
+
+    def _fetch_loop(self, conn: LoggingDatabaseConnection) -> None:
         """Takes a database connection and waits for requests for events from
         the _event_fetch_list queue.
         """
-        try:
-            i = 0
-            while True:
-                with self._event_fetch_lock:
-                    event_list = self._event_fetch_list
-                    self._event_fetch_list = []
-
-                    if not event_list:
-                        single_threaded = self.database_engine.single_threaded
-                        if (
-                            not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
-                            or single_threaded
-                            or i > EVENT_QUEUE_ITERATIONS
-                        ):
-                            break
-                        else:
-                            self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
-                            i += 1
-                            continue
-                    i = 0
+        i = 0
+        while True:
+            with self._event_fetch_lock:
+                event_list = self._event_fetch_list
+                self._event_fetch_list = []
+
+                if not event_list:
+                    # There are no requests waiting. If we haven't yet reached the
+                    # maximum iteration limit, wait for some more requests to turn up.
+                    # Otherwise, bail out.
+                    single_threaded = self.database_engine.single_threaded
+                    if (
+                        not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
+                        or single_threaded
+                        or i > EVENT_QUEUE_ITERATIONS
+                    ):
+                        return
+
+                    self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
+                    i += 1
+                    continue
+                i = 0
 
-                self._fetch_event_list(conn, event_list)
-        finally:
-            self._event_fetch_ongoing -= 1
-            event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+            self._fetch_event_list(conn, event_list)
 
     def _fetch_event_list(
-        self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
+        self,
+        conn: LoggingDatabaseConnection,
+        event_list: List[Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]],
     ) -> None:
         """Handle a load of requests from the _event_fetch_list queue
 
@@ -794,7 +910,7 @@ class EventsWorkerStore(SQLBaseStore):
                 )
 
                 # We only want to resolve deferreds from the main thread
-                def fire():
+                def fire() -> None:
                     for _, d in event_list:
                         d.callback(row_dict)
 
@@ -804,18 +920,16 @@ class EventsWorkerStore(SQLBaseStore):
                 logger.exception("do_fetch")
 
                 # We only want to resolve deferreds from the main thread
-                def fire(evs, exc):
-                    for _, d in evs:
-                        if not d.called:
-                            with PreserveLoggingContext():
-                                d.errback(exc)
+                def fire_errback(exc: Exception) -> None:
+                    for _, d in event_list:
+                        d.errback(exc)
 
                 with PreserveLoggingContext():
-                    self.hs.get_reactor().callFromThread(fire, event_list, e)
+                    self.hs.get_reactor().callFromThread(fire_errback, e)
 
     async def _get_events_from_db(
-        self, event_ids: Iterable[str]
-    ) -> Dict[str, _EventCacheEntry]:
+        self, event_ids: Collection[str]
+    ) -> Dict[str, EventCacheEntry]:
         """Fetch a bunch of events from the database.
 
         May return rejected events.
@@ -831,29 +945,29 @@ class EventsWorkerStore(SQLBaseStore):
             map from event id to result. May return extra events which
             weren't asked for.
         """
-        fetched_events = {}
+        fetched_event_ids: Set[str] = set()
+        fetched_events: Dict[str, _EventRow] = {}
         events_to_fetch = event_ids
 
         while events_to_fetch:
             row_map = await self._enqueue_events(events_to_fetch)
 
             # we need to recursively fetch any redactions of those events
-            redaction_ids = set()
+            redaction_ids: Set[str] = set()
             for event_id in events_to_fetch:
                 row = row_map.get(event_id)
-                fetched_events[event_id] = row
+                fetched_event_ids.add(event_id)
                 if row:
+                    fetched_events[event_id] = row
                     redaction_ids.update(row.redactions)
 
-            events_to_fetch = redaction_ids.difference(fetched_events.keys())
+            events_to_fetch = redaction_ids.difference(fetched_event_ids)
             if events_to_fetch:
                 logger.debug("Also fetching redaction events %s", events_to_fetch)
 
         # build a map from event_id to EventBase
-        event_map = {}
+        event_map: Dict[str, EventBase] = {}
         for event_id, row in fetched_events.items():
-            if not row:
-                continue
             assert row.event_id == event_id
 
             rejected_reason = row.rejected_reason
@@ -881,6 +995,7 @@ class EventsWorkerStore(SQLBaseStore):
 
             room_version_id = row.room_version_id
 
+            room_version: Optional[RoomVersion]
             if not room_version_id:
                 # this should only happen for out-of-band membership events which
                 # arrived before #6983 landed. For all other events, we should have
@@ -951,14 +1066,14 @@ class EventsWorkerStore(SQLBaseStore):
 
         # finally, we can decide whether each one needs redacting, and build
         # the cache entries.
-        result_map = {}
+        result_map: Dict[str, EventCacheEntry] = {}
         for event_id, original_ev in event_map.items():
             redactions = fetched_events[event_id].redactions
             redacted_event = self._maybe_redact_event_row(
                 original_ev, redactions, event_map
             )
 
-            cache_entry = _EventCacheEntry(
+            cache_entry = EventCacheEntry(
                 event=original_ev, redacted_event=redacted_event
             )
 
@@ -967,7 +1082,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         return result_map
 
-    async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]:
+    async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]:
         """Fetches events from the database using the _event_fetch_list. This
         allows batch and bulk fetching of events - it allows us to fetch events
         without having to create a new transaction for each request for events.
@@ -980,23 +1095,12 @@ class EventsWorkerStore(SQLBaseStore):
             that weren't requested.
         """
 
-        events_d = defer.Deferred()
+        events_d: "defer.Deferred[Dict[str, _EventRow]]" = defer.Deferred()
         with self._event_fetch_lock:
             self._event_fetch_list.append((events, events_d))
-
             self._event_fetch_lock.notify()
 
-            if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
-                self._event_fetch_ongoing += 1
-                event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
-                should_start = True
-            else:
-                should_start = False
-
-        if should_start:
-            run_as_background_process(
-                "fetch_events", self.db_pool.runWithConnection, self._do_fetch
-            )
+        self._maybe_start_fetch_thread()
 
         logger.debug("Loading %d events: %s", len(events), events)
         with PreserveLoggingContext():
@@ -1146,7 +1250,7 @@ class EventsWorkerStore(SQLBaseStore):
         # no valid redaction found for this event
         return None
 
-    async def have_events_in_timeline(self, event_ids):
+    async def have_events_in_timeline(self, event_ids: Iterable[str]) -> Set[str]:
         """Given a list of event ids, check if we have already processed and
         stored them as non outliers.
         """
@@ -1175,7 +1279,7 @@ class EventsWorkerStore(SQLBaseStore):
             event_ids: events we are looking for
 
         Returns:
-            set[str]: The events we have already seen.
+            The set of events we have already seen.
         """
         res = await self._have_seen_events_dict(
             (room_id, event_id) for event_id in event_ids
@@ -1198,7 +1302,9 @@ class EventsWorkerStore(SQLBaseStore):
         }
         results = {x: True for x in cache_results}
 
-        def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]):
+        def have_seen_events_txn(
+            txn: LoggingTransaction, chunk: Tuple[Tuple[str, str], ...]
+        ) -> None:
             # we deliberately do *not* query the database for room_id, to make the
             # query an index-only lookup on `events_event_id_key`.
             #
@@ -1224,12 +1330,14 @@ class EventsWorkerStore(SQLBaseStore):
         return results
 
     @cached(max_entries=100000, tree=True)
-    async def have_seen_event(self, room_id: str, event_id: str):
+    async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn:
         # this only exists for the benefit of the @cachedList descriptor on
         # _have_seen_events_dict
         raise NotImplementedError()
 
-    def _get_current_state_event_counts_txn(self, txn, room_id):
+    def _get_current_state_event_counts_txn(
+        self, txn: LoggingTransaction, room_id: str
+    ) -> int:
         """
         See get_current_state_event_counts.
         """
@@ -1254,7 +1362,7 @@ class EventsWorkerStore(SQLBaseStore):
             room_id,
         )
 
-    async def get_room_complexity(self, room_id):
+    async def get_room_complexity(self, room_id: str) -> Dict[str, float]:
         """
         Get a rough approximation of the complexity of the room. This is used by
         remote servers to decide whether they wish to join the room or not.
@@ -1262,10 +1370,10 @@ class EventsWorkerStore(SQLBaseStore):
         more resources.
 
         Args:
-            room_id (str)
+            room_id: The room ID to query.
 
         Returns:
-            dict[str:int] of complexity version to complexity.
+            dict[str:float] of complexity version to complexity.
         """
         state_events = await self.get_current_state_event_counts(room_id)
 
@@ -1275,13 +1383,13 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {"v1": complexity_v1}
 
-    def get_current_events_token(self):
+    def get_current_events_token(self) -> int:
         """The current maximum token that events have reached"""
         return self._stream_id_gen.get_current_token()
 
     async def get_all_new_forward_event_rows(
         self, instance_name: str, last_id: int, current_id: int, limit: int
-    ) -> List[Tuple]:
+    ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
         """Returns new events, for the Events replication stream
 
         Args:
@@ -1295,13 +1403,15 @@ class EventsWorkerStore(SQLBaseStore):
             EventsStreamRow.
         """
 
-        def get_all_new_forward_event_rows(txn):
+        def get_all_new_forward_event_rows(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
             sql = (
                 "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
+                " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
                 " FROM events AS e"
                 " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN state_events AS se USING (event_id)"
                 " LEFT JOIN event_relations USING (event_id)"
                 " LEFT JOIN room_memberships USING (event_id)"
                 " LEFT JOIN rejections USING (event_id)"
@@ -1311,7 +1421,9 @@ class EventsWorkerStore(SQLBaseStore):
                 " LIMIT ?"
             )
             txn.execute(sql, (last_id, current_id, instance_name, limit))
-            return txn.fetchall()
+            return cast(
+                List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+            )
 
         return await self.db_pool.runInteraction(
             "get_all_new_forward_event_rows", get_all_new_forward_event_rows
@@ -1319,7 +1431,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def get_ex_outlier_stream_rows(
         self, instance_name: str, last_id: int, current_id: int
-    ) -> List[Tuple]:
+    ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
         """Returns de-outliered events, for the Events replication stream
 
         Args:
@@ -1332,14 +1444,16 @@ class EventsWorkerStore(SQLBaseStore):
             EventsStreamRow.
         """
 
-        def get_ex_outlier_stream_rows_txn(txn):
+        def get_ex_outlier_stream_rows_txn(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
             sql = (
                 "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
+                " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
                 " FROM events AS e"
                 " INNER JOIN ex_outlier_stream AS out USING (event_id)"
                 " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN state_events AS se USING (event_id)"
                 " LEFT JOIN event_relations USING (event_id)"
                 " LEFT JOIN room_memberships USING (event_id)"
                 " LEFT JOIN rejections USING (event_id)"
@@ -1350,7 +1464,9 @@ class EventsWorkerStore(SQLBaseStore):
             )
 
             txn.execute(sql, (last_id, current_id, instance_name))
-            return txn.fetchall()
+            return cast(
+                List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+            )
 
         return await self.db_pool.runInteraction(
             "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
@@ -1358,7 +1474,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def get_all_new_backfill_event_rows(
         self, instance_name: str, last_id: int, current_id: int, limit: int
-    ) -> Tuple[List[Tuple[int, list]], int, bool]:
+    ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
         """Get updates for backfill replication stream, including all new
         backfilled events and events that have gone from being outliers to not.
 
@@ -1386,13 +1502,15 @@ class EventsWorkerStore(SQLBaseStore):
         if last_id == current_id:
             return [], current_id, False
 
-        def get_all_new_backfill_event_rows(txn):
+        def get_all_new_backfill_event_rows(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
             sql = (
                 "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id"
+                " se.state_key, redacts, relates_to_id"
                 " FROM events AS e"
                 " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN state_events AS se USING (event_id)"
                 " LEFT JOIN event_relations USING (event_id)"
                 " WHERE ? > stream_ordering AND stream_ordering >= ?"
                 "  AND instance_name = ?"
@@ -1400,7 +1518,15 @@ class EventsWorkerStore(SQLBaseStore):
                 " LIMIT ?"
             )
             txn.execute(sql, (-last_id, -current_id, instance_name, limit))
-            new_event_updates = [(row[0], row[1:]) for row in txn]
+            new_event_updates: List[
+                Tuple[int, Tuple[str, str, str, str, str, str]]
+            ] = []
+            row: Tuple[int, str, str, str, str, str, str]
+            # Type safety: iterating over `txn` yields `Tuple`, i.e.
+            # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
+            # variadic tuple to a fixed length tuple and flags it up as an error.
+            for row in txn:  # type: ignore[assignment]
+                new_event_updates.append((row[0], row[1:]))
 
             limited = False
             if len(new_event_updates) == limit:
@@ -1411,11 +1537,11 @@ class EventsWorkerStore(SQLBaseStore):
 
             sql = (
                 "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id"
+                " se.state_key, redacts, relates_to_id"
                 " FROM events AS e"
                 " INNER JOIN ex_outlier_stream AS out USING (event_id)"
                 " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN state_events AS se USING (event_id)"
                 " LEFT JOIN event_relations USING (event_id)"
                 " WHERE ? > event_stream_ordering"
                 " AND event_stream_ordering >= ?"
@@ -1423,7 +1549,11 @@ class EventsWorkerStore(SQLBaseStore):
                 " ORDER BY event_stream_ordering DESC"
             )
             txn.execute(sql, (-last_id, -upper_bound, instance_name))
-            new_event_updates.extend((row[0], row[1:]) for row in txn)
+            # Type safety: iterating over `txn` yields `Tuple`, i.e.
+            # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
+            # variadic tuple to a fixed length tuple and flags it up as an error.
+            for row in txn:  # type: ignore[assignment]
+                new_event_updates.append((row[0], row[1:]))
 
             if len(new_event_updates) >= limit:
                 upper_bound = new_event_updates[-1][0]
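A sketch of the mypy workaround described in the comments above: iterating a cursor yields Tuple[Any, ...], and assigning that to a fixed-length tuple annotation is flagged as an error, so the loop variable is annotated up front and the assignment ignored. The table and columns are stand-ins for the real schema.

import sqlite3
from typing import List, Tuple

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE stream (ordering INTEGER, event_id TEXT, room_id TEXT)")
conn.execute("INSERT INTO stream VALUES (-1, '$abc', '!room')")

updates: List[Tuple[int, Tuple[str, str]]] = []
row: Tuple[int, str, str]
cur = conn.execute("SELECT ordering, event_id, room_id FROM stream")
for row in cur:  # type: ignore[assignment]
    # Split each row into (stream position, everything else), as the
    # replication stream rows above do.
    updates.append((row[0], row[1:]))

assert updates == [(-1, ("$abc", "!room"))]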
@@ -1437,7 +1567,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def get_all_updated_current_state_deltas(
         self, instance_name: str, from_token: int, to_token: int, target_row_count: int
-    ) -> Tuple[List[Tuple], int, bool]:
+    ) -> Tuple[List[Tuple[int, str, str, str, str]], int, bool]:
         """Fetch updates from current_state_delta_stream
 
         Args:
@@ -1457,7 +1587,9 @@ class EventsWorkerStore(SQLBaseStore):
                * `limited` is whether there are more updates to fetch.
         """
 
-        def get_all_updated_current_state_deltas_txn(txn):
+        def get_all_updated_current_state_deltas_txn(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[int, str, str, str, str]]:
             sql = """
                 SELECT stream_id, room_id, type, state_key, event_id
                 FROM current_state_delta_stream
@@ -1466,21 +1598,23 @@ class EventsWorkerStore(SQLBaseStore):
                 ORDER BY stream_id ASC LIMIT ?
             """
             txn.execute(sql, (from_token, to_token, instance_name, target_row_count))
-            return txn.fetchall()
+            return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
 
-        def get_deltas_for_stream_id_txn(txn, stream_id):
+        def get_deltas_for_stream_id_txn(
+            txn: LoggingTransaction, stream_id: int
+        ) -> List[Tuple[int, str, str, str, str]]:
             sql = """
                 SELECT stream_id, room_id, type, state_key, event_id
                 FROM current_state_delta_stream
                 WHERE stream_id = ?
             """
             txn.execute(sql, [stream_id])
-            return txn.fetchall()
+            return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
 
         # we need to make sure that, for every stream id in the results, we get *all*
         # the rows with that stream id.
 
-        rows: List[Tuple] = await self.db_pool.runInteraction(
+        rows: List[Tuple[int, str, str, str, str]] = await self.db_pool.runInteraction(
             "get_all_updated_current_state_deltas",
             get_all_updated_current_state_deltas_txn,
         )
@@ -1509,14 +1643,14 @@ class EventsWorkerStore(SQLBaseStore):
 
         return rows, to_token, True
 
-    async def is_event_after(self, event_id1, event_id2):
+    async def is_event_after(self, event_id1: str, event_id2: str) -> bool:
         """Returns True if event_id1 is after event_id2 in the stream"""
         to_1, so_1 = await self.get_event_ordering(event_id1)
         to_2, so_2 = await self.get_event_ordering(event_id2)
         return (to_1, so_1) > (to_2, so_2)
 
     @cached(max_entries=5000)
-    async def get_event_ordering(self, event_id):
+    async def get_event_ordering(self, event_id: str) -> Tuple[int, int]:
         res = await self.db_pool.simple_select_one(
             table="events",
             retcols=["topological_ordering", "stream_ordering"],
@@ -1539,7 +1673,9 @@ class EventsWorkerStore(SQLBaseStore):
             None otherwise.
         """
 
-        def get_next_event_to_expire_txn(txn):
+        def get_next_event_to_expire_txn(
+            txn: LoggingTransaction,
+        ) -> Optional[Tuple[str, int]]:
             txn.execute(
                 """
                 SELECT event_id, expiry_ts FROM event_expiry
@@ -1547,7 +1683,7 @@ class EventsWorkerStore(SQLBaseStore):
                 """
             )
 
-            return txn.fetchone()
+            return cast(Optional[Tuple[str, int]], txn.fetchone())
 
         return await self.db_pool.runInteraction(
             desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
@@ -1611,10 +1747,10 @@ class EventsWorkerStore(SQLBaseStore):
         return mapping
 
     @wrap_as_background_process("_cleanup_old_transaction_ids")
-    async def _cleanup_old_transaction_ids(self):
+    async def _cleanup_old_transaction_ids(self) -> None:
         """Cleans out transaction id mappings older than 24hrs."""
 
-        def _cleanup_old_transaction_ids_txn(txn):
+        def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None:
             sql = """
                 DELETE FROM event_txn_id
                 WHERE inserted_ts < ?
@@ -1626,3 +1762,198 @@ class EventsWorkerStore(SQLBaseStore):
             "_cleanup_old_transaction_ids",
             _cleanup_old_transaction_ids_txn,
         )
+
+    async def is_event_next_to_backward_gap(self, event: EventBase) -> bool:
+        """Check if the given event is next to a backward gap of missing events.
+        <latest messages> A(False)--->B(False)--->C(True)--->  <gap, unknown events> <oldest messages>
+
+        Args:
+            event: The event to check.
+
+        Returns:
+            Boolean indicating whether the event is next to a backward gap.
+        """
+
+        def is_event_next_to_backward_gap_txn(txn: LoggingTransaction) -> bool:
+            # If the event in question has any of its prev_events listed as a
+            # backward extremity, it's next to a gap.
+            #
+            # We can't just check the backward edges in `event_edges` because
+            # when we persist events, we will also record the prev_events as
+            # edges to the event in question regardless of whether we have those
+            # prev_events yet. We need to check whether those prev_events are
+            # backward extremities, also known as gaps, that need to be
+            # backfilled.
+            backward_extremity_query = """
+                SELECT 1 FROM event_backward_extremities
+                WHERE
+                    room_id = ?
+                    AND %s
+                LIMIT 1
+            """
+
+            # If the event in question is a backward extremity or has any of its
+            # prev_events listed as a backward extremity, it's next to a
+            # backward gap.
+            clause, args = make_in_list_sql_clause(
+                self.database_engine,
+                "event_id",
+                [event.event_id] + list(event.prev_event_ids()),
+            )
+
+            txn.execute(backward_extremity_query % (clause,), [event.room_id] + args)
+            backward_extremities = txn.fetchall()
+
+            # We consider any backward extremity as a backward gap
+            if len(backward_extremities):
+                return True
+
+            return False
+
+        return await self.db_pool.runInteraction(
+            "is_event_next_to_backward_gap_txn",
+            is_event_next_to_backward_gap_txn,
+        )
+
+    async def is_event_next_to_forward_gap(self, event: EventBase) -> bool:
+        """Check if the given event is next to a forward gap of missing events.
+        The gap in front of the latest events is not considered a gap.
+        <latest messages> A(False)--->B(False)--->C(False)--->  <gap, unknown events> <oldest messages>
+        <latest messages> A(False)--->B(False)--->  <gap, unknown events>  --->D(True)--->E(False) <oldest messages>
+
+        Args:
+            event: The event to check.
+
+        Returns:
+            Boolean indicating whether the event is next to a forward gap.
+        """
+
+        def is_event_next_to_gap_txn(txn: LoggingTransaction) -> bool:
+            # If the event in question is a forward extremity, we will just
+            # consider any potential forward gap as not a gap since it's one of
+            # the latest events in the room.
+            #
+            # `event_forward_extremities` does not include backfilled or outlier
+            # events so we can't rely on it to find forward gaps. We can only
+            # use it to determine whether a message is the latest in the room.
+            #
+            # We can't combine this query with the `forward_edge_query` below
+            # because if the event in question has no forward edges (isn't
+            # referenced by any other event's prev_events) but is in
+            # `event_forward_extremities`, we don't want to return 0 rows and
+            # say it's next to a gap.
+            forward_extremity_query = """
+                SELECT 1 FROM event_forward_extremities
+                WHERE
+                    room_id = ?
+                    AND event_id = ?
+                LIMIT 1
+            """
+
+            # Check to see whether the event in question is already referenced
+            # by another event. If we don't see any edges, we're next to a
+            # forward gap.
+            forward_edge_query = """
+                SELECT 1 FROM event_edges
+                /* Check to make sure the event referencing our event in question is not rejected */
+                LEFT JOIN rejections ON event_edges.event_id = rejections.event_id
+                WHERE
+                    event_edges.room_id = ?
+                    AND event_edges.prev_event_id = ?
+                    /* It's not a valid edge if the event referencing our event in
+                     * question is rejected.
+                     */
+                    AND rejections.event_id IS NULL
+                LIMIT 1
+            """
+
+            # We consider any forward extremity as the latest in the room and
+            # not a forward gap.
+            #
+            # To expand, even though there is technically a gap at the front of
+            # the room where the forward extremities are, we consider those the
+            # latest messages in the room so asking other homeservers for more
+            # is useless. The new latest messages will just be federated as
+            # usual.
+            txn.execute(forward_extremity_query, (event.room_id, event.event_id))
+            forward_extremities = txn.fetchall()
+            if len(forward_extremities):
+                return False
+
+            # If there are no forward edges to the event in question (another
+            # event hasn't referenced this event in their prev_events), then we
+            # assume there is a forward gap in the history.
+            txn.execute(forward_edge_query, (event.room_id, event.event_id))
+            forward_edges = txn.fetchall()
+            if not len(forward_edges):
+                return True
+
+            return False
+
+        return await self.db_pool.runInteraction(
+            "is_event_next_to_gap_txn",
+            is_event_next_to_gap_txn,
+        )
+
+    async def get_event_id_for_timestamp(
+        self, room_id: str, timestamp: int, direction: str
+    ) -> Optional[str]:
+        """Find the closest event to the given timestamp in the given direction.
+
+        Args:
+            room_id: Room to fetch the event from
+            timestamp: The point in time (inclusive) we should navigate from in
+                the given direction to find the closest event.
+            direction: ["f"|"b"] to indicate whether we should navigate forward
+                or backward from the given timestamp to find the closest event.
+
+        Returns:
+            The closest event_id, or None if we can't find any event in the
+            given direction.
+        """
+
+        sql_template = """
+            SELECT event_id FROM events
+            LEFT JOIN rejections USING (event_id)
+            WHERE
+                origin_server_ts %s ?
+                AND room_id = ?
+                /* Make sure event is not rejected */
+                AND rejections.event_id IS NULL
+            ORDER BY origin_server_ts %s
+            LIMIT 1;
+        """
+
+        def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]:
+            if direction == "b":
+                # Find closest event *before* a given timestamp. We use descending
+                # (which gives values largest to smallest) because we want the
+                # largest possible timestamp *before* the given timestamp.
+                comparison_operator = "<="
+                order = "DESC"
+            else:
+                # Find closest event *after* a given timestamp. We use ascending
+                # (which gives values smallest to largest) because we want the
+                # closest possible timestamp *after* the given timestamp.
+                comparison_operator = ">="
+                order = "ASC"
+
+            txn.execute(
+                sql_template % (comparison_operator, order), (timestamp, room_id)
+            )
+            row = txn.fetchone()
+            if row:
+                (event_id,) = row
+                return event_id
+
+            return None
+
+        if direction not in ("f", "b"):
+            raise ValueError("Unknown direction: %s" % (direction,))
+
+        return await self.db_pool.runInteraction(
+            "get_event_id_for_timestamp_txn",
+            get_event_id_for_timestamp_txn,
+        )
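
A minimal standalone sketch of how the template above is filled in per direction,
using an in-memory SQLite table as a stand-in for the real `events` table (the
rejections join is omitted for brevity; this is illustrative, not Synapse code):

    import sqlite3

    # Simplified stand-in schema for illustration only.
    SQL_TEMPLATE = """
        SELECT event_id FROM events
        WHERE origin_server_ts %s ? AND room_id = ?
        ORDER BY origin_server_ts %s
        LIMIT 1;
    """

    def closest_event_id(conn, room_id, timestamp, direction):
        if direction == "b":
            comparison_operator, order = "<=", "DESC"
        else:
            comparison_operator, order = ">=", "ASC"
        cur = conn.execute(
            SQL_TEMPLATE % (comparison_operator, order), (timestamp, room_id)
        )
        row = cur.fetchone()
        return row[0] if row else None

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE events (event_id TEXT, room_id TEXT, origin_server_ts INTEGER)"
    )
    conn.executemany(
        "INSERT INTO events VALUES (?, ?, ?)",
        [("$a", "!room", 1000), ("$b", "!room", 2000), ("$c", "!room", 3000)],
    )
    assert closest_event_id(conn, "!room", 2500, "b") == "$b"
    assert closest_event_id(conn, "!room", 2500, "f") == "$c"
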
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 3eb30944bf..91b0576b85 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -118,7 +118,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
 
         logger.info("[purge] looking for events to delete")
 
-        should_delete_expr = "state_key IS NULL"
+        should_delete_expr = "state_events.state_key IS NULL"
         should_delete_params: Tuple[Any, ...] = ()
         if not delete_local_events:
             should_delete_expr += " AND event_id NOT LIKE ?"
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index fa782023d4..3b63267395 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -28,7 +28,10 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.util.id_generators import (
+    AbstractStreamIdTracker,
+    StreamIdGenerator,
+)
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -82,9 +85,9 @@ class PushRulesWorkerStore(
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.worker_app is None:
-            self._push_rules_stream_id_gen: Union[
-                StreamIdGenerator, SlavedIdTracker
-            ] = StreamIdGenerator(db_conn, "push_rules_stream", "stream_id")
+            self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+                db_conn, "push_rules_stream", "stream_id"
+            )
         else:
             self._push_rules_stream_id_gen = SlavedIdTracker(
                 db_conn, "push_rules_stream", "stream_id"
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 5e55440570..e1ddf06916 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -84,28 +84,37 @@ class TokenLookupResult:
         return self.user_id
 
 
-@attr.s(frozen=True, slots=True)
+@attr.s(auto_attribs=True, frozen=True, slots=True)
 class RefreshTokenLookupResult:
     """Result of looking up a refresh token."""
 
-    user_id = attr.ib(type=str)
+    user_id: str
     """The user this token belongs to."""
 
-    device_id = attr.ib(type=str)
+    device_id: str
     """The device associated with this refresh token."""
 
-    token_id = attr.ib(type=int)
+    token_id: int
     """The ID of this refresh token."""
 
-    next_token_id = attr.ib(type=Optional[int])
+    next_token_id: Optional[int]
     """The ID of the refresh token which replaced this one."""
 
-    has_next_refresh_token_been_refreshed = attr.ib(type=bool)
+    has_next_refresh_token_been_refreshed: bool
     """True if the next refresh token was used for another refresh."""
 
-    has_next_access_token_been_used = attr.ib(type=bool)
+    has_next_access_token_been_used: bool
     """True if the next access token was already used at least once."""
 
+    expiry_ts: Optional[int]
+    """The time at which the refresh token expires and can no longer be used.
+    If None, the refresh token doesn't expire."""
+
+    ultimate_session_expiry_ts: Optional[int]
+    """The time at which the session comes to an end and can no longer be
+    refreshed.
+    If None, the session can be refreshed indefinitely."""
+
 
 class RegistrationWorkerStore(CacheInvalidationWorkerStore):
     def __init__(
@@ -1198,8 +1207,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         expiration_ts = now_ms + self._account_validity_period
 
         if use_delta:
+            assert self._account_validity_startup_job_max_delta is not None
             expiration_ts = random.randrange(
-                expiration_ts - self._account_validity_startup_job_max_delta,
+                int(expiration_ts - self._account_validity_startup_job_max_delta),
                 expiration_ts,
             )
 
@@ -1625,8 +1635,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                     rt.user_id,
                     rt.device_id,
                     rt.next_token_id,
-                    (nrt.next_token_id IS NOT NULL) has_next_refresh_token_been_refreshed,
-                    at.used has_next_access_token_been_used
+                    (nrt.next_token_id IS NOT NULL) AS has_next_refresh_token_been_refreshed,
+                    at.used AS has_next_access_token_been_used,
+                    rt.expiry_ts,
+                    rt.ultimate_session_expiry_ts
                 FROM refresh_tokens rt
                 LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id
                 LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id
@@ -1647,6 +1659,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 has_next_refresh_token_been_refreshed=row[4],
                 # This column is nullable, ensure it's a boolean
                 has_next_access_token_been_used=(row[5] or False),
+                expiry_ts=row[6],
+                ultimate_session_expiry_ts=row[7],
             )
 
         return await self.db_pool.runInteraction(
@@ -1728,11 +1742,11 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
         )
 
         self.db_pool.updates.register_background_update_handler(
-            "user_threepids_grandfather", self._bg_user_threepids_grandfather
+            "users_set_deactivated_flag", self._background_update_set_deactivated_flag
         )
 
-        self.db_pool.updates.register_background_update_handler(
-            "users_set_deactivated_flag", self._background_update_set_deactivated_flag
+        self.db_pool.updates.register_noop_background_update(
+            "user_threepids_grandfather"
         )
 
         self.db_pool.updates.register_background_index_update(
@@ -1805,35 +1819,6 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
 
         return nb_processed
 
-    async def _bg_user_threepids_grandfather(self, progress, batch_size):
-        """We now track which identity servers a user binds their 3PID to, so
-        we need to handle the case of existing bindings where we didn't track
-        this.
-
-        We do this by grandfathering in existing user threepids assuming that
-        they used one of the server configured trusted identity servers.
-        """
-        id_servers = set(self.config.registration.trusted_third_party_id_servers)
-
-        def _bg_user_threepids_grandfather_txn(txn):
-            sql = """
-                INSERT INTO user_threepid_id_server
-                    (user_id, medium, address, id_server)
-                SELECT user_id, medium, address, ?
-                FROM user_threepids
-            """
-
-            txn.execute_batch(sql, [(id_server,) for id_server in id_servers])
-
-        if id_servers:
-            await self.db_pool.runInteraction(
-                "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
-            )
-
-        await self.db_pool.updates._end_background_update("user_threepids_grandfather")
-
-        return 1
-
     async def set_user_deactivated_status(
         self, user_id: str, deactivated: bool
     ) -> None:
@@ -1943,6 +1928,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
         user_id: str,
         token: str,
         device_id: Optional[str],
+        expiry_ts: Optional[int],
+        ultimate_session_expiry_ts: Optional[int],
     ) -> int:
         """Adds a refresh token for the given user.
 
@@ -1950,6 +1937,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
             user_id: The user ID.
             token: The new refresh token to add.
             device_id: ID of the device to associate with the refresh token.
+            expiry_ts (milliseconds since the epoch): Time after which the
+                refresh token cannot be used.
+                If None, the refresh token doesn't expire until it has been used.
+            ultimate_session_expiry_ts (milliseconds since the epoch):
+                Time at which the session will end and cannot be extended any
+                further.
+                If None, the session can be refreshed indefinitely.
         Raises:
             StoreError if there was a problem adding this.
         Returns:
@@ -1965,6 +1959,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
                 "device_id": device_id,
                 "token": token,
                 "next_token_id": None,
+                "expiry_ts": expiry_ts,
+                "ultimate_session_expiry_ts": ultimate_session_expiry_ts,
             },
             desc="add_refresh_token_to_user",
         )
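
To illustrate what the two new nullable columns passed to `add_refresh_token_to_user`
hold, here is a small standalone sketch of how a caller might derive them from
configured lifetimes. The parameter names below are hypothetical placeholders, not
Synapse's actual configuration options:

    import time
    from typing import Optional, Tuple

    def compute_refresh_token_expiries(
        now_ms: int,
        refresh_token_lifetime_ms: Optional[int],  # hypothetical config value
        session_lifetime_ms: Optional[int],        # hypothetical config value
    ) -> Tuple[Optional[int], Optional[int]]:
        # expiry_ts: when this particular refresh token stops being usable.
        expiry_ts = (
            now_ms + refresh_token_lifetime_ms
            if refresh_token_lifetime_ms is not None
            else None
        )
        # ultimate_session_expiry_ts: when the whole session ends, no matter
        # how many times the tokens are refreshed.
        ultimate_session_expiry_ts = (
            now_ms + session_lifetime_ms if session_lifetime_ms is not None else None
        )
        return expiry_ts, ultimate_session_expiry_ts

    now_ms = int(time.time() * 1000)
    print(compute_refresh_token_expiries(now_ms, 5 * 60_000, 24 * 3_600_000))
    print(compute_refresh_token_expiries(now_ms, None, None))  # (None, None): no expiry
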
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 907af10995..0a43acda07 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -132,6 +132,69 @@ class RelationsWorkerStore(SQLBaseStore):
             "get_recent_references_for_event", _get_recent_references_for_event_txn
         )
 
+    async def event_includes_relation(self, event_id: str) -> bool:
+        """Check if the given event relates to another event.
+
+        An event has a relation if it has a valid m.relates_to with a rel_type
+        and event_id in the content:
+
+        {
+            "content": {
+                "m.relates_to": {
+                    "rel_type": "m.replace",
+                    "event_id": "$other_event_id"
+                }
+            }
+        }
+
+        Args:
+            event_id: The event to check.
+
+        Returns:
+            True if the event includes a valid relation.
+        """
+
+        result = await self.db_pool.simple_select_one_onecol(
+            table="event_relations",
+            keyvalues={"event_id": event_id},
+            retcol="event_id",
+            allow_none=True,
+            desc="event_includes_relation",
+        )
+        return result is not None
+
+    async def event_is_target_of_relation(self, parent_id: str) -> bool:
+        """Check if the given event is the target of another event's relation.
+
+        An event is the target of an event relation if it has a valid
+        m.relates_to with a rel_type and event_id pointing to parent_id in the
+        content:
+
+        {
+            "content": {
+                "m.relates_to": {
+                    "rel_type": "m.replace",
+                    "event_id": "$parent_id"
+                }
+            }
+        }
+
+        Args:
+            parent_id: The event to check.
+
+        Returns:
+            True if the event is the target of another event's relation.
+        """
+
+        result = await self.db_pool.simple_select_one_onecol(
+            table="event_relations",
+            keyvalues={"relates_to_id": parent_id},
+            retcol="event_id",
+            allow_none=True,
+            desc="event_is_target_of_relation",
+        )
+        return result is not None
+
     @cached(tree=True)
     async def get_aggregation_groups_for_event(
         self,
@@ -362,7 +425,7 @@ class RelationsWorkerStore(SQLBaseStore):
                 %s;
         """
 
-        def _get_if_event_has_relations(txn) -> List[str]:
+        def _get_if_events_have_relations(txn) -> List[str]:
             clauses: List[str] = []
             clause, args = make_in_list_sql_clause(
                 txn.database_engine, "relates_to_id", parent_ids
@@ -387,7 +450,7 @@ class RelationsWorkerStore(SQLBaseStore):
             return [row[0] for row in txn]
 
         return await self.db_pool.runInteraction(
-            "get_if_event_has_relations", _get_if_event_has_relations
+            "get_if_events_have_relations", _get_if_events_have_relations
         )
 
     async def has_user_annotated_event(
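
As a companion to the JSON shape shown in the docstrings above, here is a small
standalone sketch of what makes an `m.relates_to` block "valid" (a `rel_type` plus an
`event_id`). It mirrors the documented structure only; the storage methods themselves
consult the `event_relations` table rather than event content:

    from typing import Any, Mapping, Optional, Tuple

    def parse_relation(content: Mapping[str, Any]) -> Optional[Tuple[str, str]]:
        """Return (rel_type, related event_id) if the content carries a valid relation."""
        relates_to = content.get("m.relates_to")
        if not isinstance(relates_to, dict):
            return None
        rel_type = relates_to.get("rel_type")
        event_id = relates_to.get("event_id")
        if isinstance(rel_type, str) and isinstance(event_id, str):
            return rel_type, event_id
        return None

    assert parse_relation(
        {"m.relates_to": {"rel_type": "m.replace", "event_id": "$other_event_id"}}
    ) == ("m.replace", "$other_event_id")
    assert parse_relation({"body": "no relation here"}) is None
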
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 17b398bb69..7d694d852d 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -397,6 +397,20 @@ class RoomWorkerStore(SQLBaseStore):
             desc="is_room_blocked",
         )
 
+    async def room_is_blocked_by(self, room_id: str) -> Optional[str]:
+        """
+        Retrieve the user who has blocked the room, if any.
+        Returns the user_id (non-nullable in the blocked_rooms table), or None
+        if the room is not blocked.
+        """
+        return await self.db_pool.simple_select_one_onecol(
+            table="blocked_rooms",
+            keyvalues={"room_id": room_id},
+            retcol="user_id",
+            allow_none=True,
+            desc="room_is_blocked_by",
+        )
+
     async def get_rooms_paginate(
         self,
         start: int,
@@ -1775,3 +1789,21 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             self.is_room_blocked,
             (room_id,),
         )
+
+    async def unblock_room(self, room_id: str) -> None:
+        """Remove the room from blocking list.
+
+        Args:
+            room_id: Room to unblock
+        """
+        await self.db_pool.simple_delete(
+            table="blocked_rooms",
+            keyvalues={"room_id": room_id},
+            desc="unblock_room",
+        )
+        await self.db_pool.runInteraction(
+            "block_room_invalidation",
+            self._invalidate_cache_and_stream,
+            self.is_room_blocked,
+            (room_id,),
+        )
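
A standalone sketch of the block/unblock semantics exposed by `room_is_blocked_by`
and `unblock_room`, using a plain dict in place of the `blocked_rooms` table (an
illustration of the intended behaviour, not Synapse code):

    from typing import Dict, Optional

    class FakeBlockedRooms:
        def __init__(self) -> None:
            # room_id -> user_id of the user who blocked it
            self._blocked: Dict[str, str] = {}

        def block_room(self, room_id: str, user_id: str) -> None:
            self._blocked[room_id] = user_id

        def room_is_blocked_by(self, room_id: str) -> Optional[str]:
            # The blocking user's ID, or None if the room is not blocked.
            return self._blocked.get(room_id)

        def unblock_room(self, room_id: str) -> None:
            # Deleting the row effectively unblocks the room.
            self._blocked.pop(room_id, None)

    store = FakeBlockedRooms()
    store.block_room("!abc:example.org", "@admin:example.org")
    assert store.room_is_blocked_by("!abc:example.org") == "@admin:example.org"
    store.unblock_room("!abc:example.org")
    assert store.room_is_blocked_by("!abc:example.org") is None
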
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 033a9831d6..6b2a8d06a6 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -476,7 +476,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 INNER JOIN events AS e USING (room_id, event_id)
                 WHERE
                     c.type = 'm.room.member'
-                    AND state_key = ?
+                    AND c.state_key = ?
                     AND c.membership = ?
             """
         else:
@@ -487,7 +487,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 INNER JOIN events AS e USING (room_id, event_id)
                 WHERE
                     c.type = 'm.room.member'
-                    AND state_key = ?
+                    AND c.state_key = ?
                     AND m.membership = ?
             """
 
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 42dc807d17..57aab55259 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -497,7 +497,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
                 oldest `limit` events.
 
         Returns:
-            The list of events (in ascending order) and the token from the start
+            The list of events (in ascending stream order) and the token from the start
             of the chunk of events returned.
         """
         if from_key == to_key:
@@ -510,7 +510,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
         if not has_changed:
             return [], from_key
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> List[_EventDictReturn]:
             # To handle tokens with a non-empty instance_map we fetch more
             # results than necessary and then filter down
             min_from_id = from_key.stream
@@ -565,6 +565,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
     async def get_membership_changes_for_user(
         self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
     ) -> List[EventBase]:
+        """Fetch membership events for a given user.
+
+        All such events whose stream ordering `s` lies in the range
+        `from_key < s <= to_key` are returned. Events are ordered by ascending stream
+        order.
+        """
+        # Start by ruling out cases where a DB query is not necessary.
         if from_key == to_key:
             return []
 
@@ -575,7 +582,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
             if not has_changed:
                 return []
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> List[_EventDictReturn]:
             # To handle tokens with a non-empty instance_map we fetch more
             # results than necessary and then filter down
             min_from_id = from_key.stream
@@ -634,7 +641,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
 
         Returns:
             A list of events and a token pointing to the start of the returned
-            events. The events returned are in ascending order.
+            events. The events returned are in ascending topological order.
         """
 
         rows, token = await self.get_recent_event_ids_for_room(
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index d7dc1f73ac..1622822552 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -14,6 +14,7 @@
 
 import logging
 from collections import namedtuple
+from enum import Enum
 from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
 
 import attr
@@ -44,6 +45,16 @@ _UpdateTransactionRow = namedtuple(
 )
 
 
+class DestinationSortOrder(Enum):
+    """Enum to define the sorting method used when returning destinations."""
+
+    DESTINATION = "destination"
+    RETRY_LAST_TS = "retry_last_ts"
+    RETRY_INTERVAL = "retry_interval"
+    FAILURE_TS = "failure_ts"
+    LAST_SUCCESSFUL_STREAM_ORDERING = "last_successful_stream_ordering"
+
+
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class DestinationRetryTimings:
     """The current destination retry timing info for a remote server."""
@@ -480,3 +491,62 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
 
         destinations = [row[0] for row in txn]
         return destinations
+
+    async def get_destinations_paginate(
+        self,
+        start: int,
+        limit: int,
+        destination: Optional[str] = None,
+        order_by: str = DestinationSortOrder.DESTINATION.value,
+        direction: str = "f",
+    ) -> Tuple[List[JsonDict], int]:
+        """Function to retrieve a paginated list of destinations.
+        This will return a json list of destinations and the
+        total number of destinations matching the filter criteria.
+
+        Args:
+            start: start number to begin the query from
+            limit: number of rows to retrieve
+            destination: search string in destination
+            order_by: the sort order of the returned list
+            direction: sort ascending or descending
+        Returns:
+            A tuple of a list of mappings from destination to information
+            and a count of total destinations.
+        """
+
+        def get_destinations_paginate_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[JsonDict], int]:
+            order_by_column = DestinationSortOrder(order_by).value
+
+            if direction == "b":
+                order = "DESC"
+            else:
+                order = "ASC"
+
+            args = []
+            where_statement = ""
+            if destination:
+                args.extend(["%" + destination.lower() + "%"])
+                where_statement = "WHERE LOWER(destination) LIKE ?"
+
+            sql_base = f"FROM destinations {where_statement} "
+            sql = f"SELECT COUNT(*) as total_destinations {sql_base}"
+            txn.execute(sql, args)
+            count = txn.fetchone()[0]
+
+            sql = f"""
+                SELECT destination, retry_last_ts, retry_interval, failure_ts,
+                last_successful_stream_ordering
+                {sql_base}
+                ORDER BY {order_by_column} {order}, destination ASC
+                LIMIT ? OFFSET ?
+            """
+            txn.execute(sql, args + [limit, start])
+            destinations = self.db_pool.cursor_to_dict(txn)
+            return destinations, count
+
+        return await self.db_pool.runInteraction(
+            "get_destinations_paginate_txn", get_destinations_paginate_txn
+        )
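
A standalone sketch of how `order_by` and `direction` shape the final query built in
`get_destinations_paginate` (the enum is copied from the hunk above; the rest is
illustrative only):

    from enum import Enum

    class DestinationSortOrder(Enum):
        DESTINATION = "destination"
        RETRY_LAST_TS = "retry_last_ts"
        RETRY_INTERVAL = "retry_interval"
        FAILURE_TS = "failure_ts"
        LAST_SUCCESSFUL_STREAM_ORDERING = "last_successful_stream_ordering"

    def build_order_clause(order_by: str, direction: str) -> str:
        # DestinationSortOrder(order_by) raises ValueError for unknown columns,
        # so only whitelisted column names can reach the SQL string.
        column = DestinationSortOrder(order_by).value
        order = "DESC" if direction == "b" else "ASC"
        return f"ORDER BY {column} {order}, destination ASC"

    print(build_order_clause("retry_last_ts", "b"))
    # -> ORDER BY retry_last_ts DESC, destination ASC
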
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 402f134d89..428d66a617 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -583,7 +583,8 @@ class EventsPersistenceStorage:
                 current_state_for_room=current_state_for_room,
                 state_delta_for_room=state_delta_for_room,
                 new_forward_extremeties=new_forward_extremeties,
-                backfilled=backfilled,
+                use_negative_stream_ordering=backfilled,
+                inhibit_local_membership_updates=backfilled,
             )
 
             await self._handle_potentially_left_users(potentially_left_users)
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 8b9c6adae2..e45adfcb55 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -131,24 +131,16 @@ def prepare_database(
                     "config==None in prepare_database, but database is not empty"
                 )
 
-            # if it's a worker app, refuse to upgrade the database, to avoid multiple
-            # workers doing it at once.
-            if config.worker.worker_app is None:
-                _upgrade_existing_database(
-                    cur,
-                    version_info,
-                    database_engine,
-                    config,
-                    databases=databases,
-                )
-            elif version_info.current_version < SCHEMA_VERSION:
-                # If the DB is on an older version than we expect then we refuse
-                # to start the worker (as the main process needs to run first to
-                # update the schema).
-                raise UpgradeDatabaseException(
-                    OUTDATED_SCHEMA_ON_WORKER_ERROR
-                    % (SCHEMA_VERSION, version_info.current_version)
-                )
+            # This should be run on all processes, master or worker. The master will
+            # apply the deltas, while workers will check whether any outstanding deltas
+            # exist and raise a PrepareDatabaseException if they do.
+            _upgrade_existing_database(
+                cur,
+                version_info,
+                database_engine,
+                config,
+                databases=databases,
+            )
 
         else:
             logger.info("%r: Initialising new database", databases)
@@ -358,6 +350,18 @@ def _upgrade_existing_database(
 
     is_worker = config and config.worker.worker_app is not None
 
+    # If the schema version needs to be updated, and we are on a worker, we immediately
+    # know to bail out as workers cannot update the database schema. Only one process
+    # may update the database at a time, so we delegate this task to the master.
+    if is_worker and current_schema_state.current_version < SCHEMA_VERSION:
+        # If the DB is on an older version than we expect then we refuse
+        # to start the worker (as the main process needs to run first to
+        # update the schema).
+        raise UpgradeDatabaseException(
+            OUTDATED_SCHEMA_ON_WORKER_ERROR
+            % (SCHEMA_VERSION, current_schema_state.current_version)
+        )
+
     if (
         current_schema_state.compat_version is not None
         and current_schema_state.compat_version > SCHEMA_VERSION
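
A standalone sketch of the startup behaviour described in the comments above: the
main process applies outstanding deltas, while a worker that finds the database
behind refuses to start (the version numbers and message are illustrative only):

    SCHEMA_VERSION = 66  # illustrative value

    def check_schema_on_startup(is_worker: bool, current_version: int) -> str:
        if is_worker and current_version < SCHEMA_VERSION:
            raise RuntimeError(
                "Database schema is at version %d but %d is expected: "
                "start the main process first so it can upgrade the schema"
                % (current_version, SCHEMA_VERSION)
            )
        return "apply deltas" if current_version < SCHEMA_VERSION else "up to date"

    assert check_schema_on_startup(is_worker=False, current_version=65) == "apply deltas"
    assert check_schema_on_startup(is_worker=False, current_version=66) == "up to date"
    try:
        check_schema_on_startup(is_worker=True, current_version=65)
    except RuntimeError:
        pass  # a worker on an outdated schema refuses to start
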
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 3a00ed6835..50d08094d5 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-SCHEMA_VERSION = 65  # remember to update the list below when updating
+SCHEMA_VERSION = 66  # remember to update the list below when updating
 """Represents the expectations made by the codebase about the database schema
 
 This should be incremented whenever the codebase changes its requirements on the
@@ -46,6 +46,10 @@ Changes in SCHEMA_VERSION = 65:
     - MSC2716: Remove unique event_id constraint from insertion_event_edges
       because an insertion event can have multiple edges.
     - Remove unused tables `user_stats_historical` and `room_stats_historical`.
+
+Changes in SCHEMA_VERSION = 66:
+    - Queries on state_key columns are now disambiguated (ie, the codebase can handle
+      the `events` table having a `state_key` column).
 """
 
 
diff --git a/synapse/storage/schema/main/delta/65/06remove_deleted_devices_from_device_inbox.sql b/synapse/storage/schema/main/delta/65/06remove_deleted_devices_from_device_inbox.sql
new file mode 100644
index 0000000000..82f6408b36
--- /dev/null
+++ b/synapse/storage/schema/main/delta/65/06remove_deleted_devices_from_device_inbox.sql
@@ -0,0 +1,34 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- Remove messages from the device_inbox table which were orphaned
+-- when a device was deleted using Synapse earlier than 1.47.0.
+-- This runs as a background task, but may take a bit to finish.
+
+-- Remove any existing instances of this job running. It's OK to stop and restart this job,
+-- as it's just deleting entries from a table - no progress will be lost.
+--
+-- This is necessary because a similar migration running the same job was
+-- accidentally included in schema version 64 during v1.47.0rc1/rc2. If a
+-- homeserver had updated from Synapse <=v1.45.0 (schema version <=64),
+-- then it would have started running this background update already.
+-- If that update was still running, then simply inserting it again would
+-- cause an SQL failure. So we effectively do an "upsert" here instead.
+
+DELETE FROM background_updates WHERE update_name = 'remove_deleted_devices_from_device_inbox';
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (6506, 'remove_deleted_devices_from_device_inbox', '{}');
diff --git a/synapse/storage/schema/main/delta/65/02_thread_relations.sql b/synapse/storage/schema/main/delta/65/07_arbitrary_relations.sql
index d60517f7b4..267b2cb539 100644
--- a/synapse/storage/schema/main/delta/65/02_thread_relations.sql
+++ b/synapse/storage/schema/main/delta/65/07_arbitrary_relations.sql
@@ -15,4 +15,4 @@
 
 -- Check old events for thread relations.
 INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
-  (6502, 'event_thread_relation', '{}');
+  (6507, 'event_arbitrary_relations', '{}');
diff --git a/synapse/storage/schema/main/delta/65/05remove_deleted_devices_from_device_inbox.sql b/synapse/storage/schema/main/delta/65/08_device_inbox_background_updates.sql
index 076179123d..d79455c2ce 100644
--- a/synapse/storage/schema/main/delta/65/05remove_deleted_devices_from_device_inbox.sql
+++ b/synapse/storage/schema/main/delta/65/08_device_inbox_background_updates.sql
@@ -13,10 +13,6 @@
  * limitations under the License.
  */
 
-
--- Remove messages from the device_inbox table which were orphaned
--- when a device was deleted using Synapse earlier than 1.47.0.
--- This runs as background task, but may take a bit to finish.
-
+-- Background update to clear the inboxes of hidden and deleted devices.
 INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
-  (6505, 'remove_deleted_devices_from_device_inbox', '{}');
+  (6508, 'remove_dead_devices_from_device_inbox', '{}');
diff --git a/synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql b/synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql
new file mode 100644
index 0000000000..bdc491c817
--- /dev/null
+++ b/synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql
@@ -0,0 +1,28 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+ALTER TABLE refresh_tokens
+  -- We add an expiry_ts column (in milliseconds since the Epoch) to refresh tokens.
+  -- They may not be used after they have expired.
+  -- If null, then the refresh token's lifetime is unlimited.
+  ADD COLUMN expiry_ts BIGINT DEFAULT NULL;
+
+ALTER TABLE refresh_tokens
+  -- We also add an ultimate session expiry time (in milliseconds since the Epoch).
+  -- No matter how much the access and refresh tokens are refreshed, they cannot
+  -- be extended past this time.
+  -- If null, then the session length is unlimited.
+  ADD COLUMN ultimate_session_expiry_ts BIGINT DEFAULT NULL;
diff --git a/synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql b/synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql
new file mode 100644
index 0000000000..a65bfb520d
--- /dev/null
+++ b/synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql
@@ -0,0 +1,27 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Track the auth provider used by each login as well as the session ID
+CREATE TABLE device_auth_providers (
+  user_id TEXT NOT NULL,
+  device_id TEXT NOT NULL,
+  auth_provider_id TEXT NOT NULL,
+  auth_provider_session_id TEXT NOT NULL
+);
+
+CREATE INDEX device_auth_providers_devices
+  ON device_auth_providers (user_id, device_id);
+CREATE INDEX device_auth_providers_sessions
+  ON device_auth_providers (auth_provider_id, auth_provider_session_id);
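
As a rough illustration of the rows this new table is meant to hold, here is a
standalone sketch using an in-memory SQLite database (all identifiers below are
made up for illustration):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        """
        CREATE TABLE device_auth_providers (
            user_id TEXT NOT NULL,
            device_id TEXT NOT NULL,
            auth_provider_id TEXT NOT NULL,
            auth_provider_session_id TEXT NOT NULL
        )
        """
    )
    # One row per login: which SSO provider the device logged in through,
    # and the provider-side session ID for that login.
    conn.execute(
        "INSERT INTO device_auth_providers VALUES (?, ?, ?, ?)",
        ("@alice:example.org", "DEVICEID1", "oidc-myprovider", "sid-abc123"),
    )
    row = conn.execute(
        "SELECT auth_provider_id, auth_provider_session_id"
        " FROM device_auth_providers WHERE user_id = ? AND device_id = ?",
        ("@alice:example.org", "DEVICEID1"),
    ).fetchone()
    assert row == ("oidc-myprovider", "sid-abc123")
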
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index ac56bc9a05..4ff3013908 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -89,31 +89,77 @@ def _load_current_id(
     return (max if step > 0 else min)(current_id, step)
 
 
-class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
-    @abc.abstractmethod
-    def get_next(self) -> AsyncContextManager[int]:
-        raise NotImplementedError()
+class AbstractStreamIdTracker(metaclass=abc.ABCMeta):
+    """Tracks the "current" stream ID of a stream that may have multiple writers.
+
+    Stream IDs are monotonically increasing or decreasing integers representing write
+    transactions. The "current" stream ID is the stream ID such that all transactions
+    with equal or smaller stream IDs have completed. Since transactions may complete out
+    of order, this is not the same as the stream ID of the last completed transaction.
+
+    Completed transactions include both committed transactions and transactions that
+    have been rolled back.
+    """
 
     @abc.abstractmethod
-    def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
+    def advance(self, instance_name: str, new_id: int) -> None:
+        """Advance the position of the named writer to the given ID, if greater
+        than existing entry.
+        """
         raise NotImplementedError()
 
     @abc.abstractmethod
     def get_current_token(self) -> int:
+        """Returns the maximum stream id such that all stream ids less than or
+        equal to it have been successfully persisted.
+
+        Returns:
+            The maximum stream id.
+        """
         raise NotImplementedError()
 
     @abc.abstractmethod
     def get_current_token_for_writer(self, instance_name: str) -> int:
+        """Returns the position of the given writer.
+
+        For streams with single writers this is equivalent to `get_current_token`.
+        """
+        raise NotImplementedError()
+
+
+class AbstractStreamIdGenerator(AbstractStreamIdTracker):
+    """Generates stream IDs for a stream that may have multiple writers.
+
+    Each stream ID represents a write transaction, whose completion is tracked
+    so that the "current" stream ID of the stream can be determined.
+
+    See `AbstractStreamIdTracker` for more details.
+    """
+
+    @abc.abstractmethod
+    def get_next(self) -> AsyncContextManager[int]:
+        """
+        Usage:
+            async with stream_id_gen.get_next() as stream_id:
+                # ... persist event ...
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
+        """
+        Usage:
+            async with stream_id_gen.get_next_mult(n) as stream_ids:
+                # ... persist events ...
+        """
         raise NotImplementedError()
 
 
 class StreamIdGenerator(AbstractStreamIdGenerator):
-    """Used to generate new stream ids when persisting events while keeping
-    track of which transactions have been completed.
+    """Generates and tracks stream IDs for a stream with a single writer.
 
-    This allows us to get the "current" stream id, i.e. the stream id such that
-    all ids less than or equal to it have completed. This handles the fact that
-    persistence of events can complete out of order.
+    This class must only be used when the current Synapse process is the sole
+    writer for a stream.
 
     Args:
         db_conn(connection):  A database connection to use to fetch the
@@ -157,12 +203,12 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
         # The key and values are the same, but we never look at the values.
         self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
 
+    def advance(self, instance_name: str, new_id: int) -> None:
+        # `StreamIdGenerator` should only be used when there is a single writer,
+        # so replication should never happen.
+        raise Exception("Replication is not supported by StreamIdGenerator")
+
     def get_next(self) -> AsyncContextManager[int]:
-        """
-        Usage:
-            async with stream_id_gen.get_next() as stream_id:
-                # ... persist event ...
-        """
         with self._lock:
             self._current += self._step
             next_id = self._current
@@ -180,11 +226,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
         return _AsyncCtxManagerWrapper(manager())
 
     def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
-        """
-        Usage:
-            async with stream_id_gen.get_next(n) as stream_ids:
-                # ... persist events ...
-        """
         with self._lock:
             next_ids = range(
                 self._current + self._step,
@@ -208,12 +249,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
         return _AsyncCtxManagerWrapper(manager())
 
     def get_current_token(self) -> int:
-        """Returns the maximum stream id such that all stream ids less than or
-        equal to it have been successfully persisted.
-
-        Returns:
-            The maximum stream id.
-        """
         with self._lock:
             if self._unfinished_ids:
                 return next(iter(self._unfinished_ids)) - self._step
@@ -221,16 +256,11 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
             return self._current
 
     def get_current_token_for_writer(self, instance_name: str) -> int:
-        """Returns the position of the given writer.
-
-        For streams with single writers this is equivalent to
-        `get_current_token`.
-        """
         return self.get_current_token()
 
 
 class MultiWriterIdGenerator(AbstractStreamIdGenerator):
-    """An ID generator that tracks a stream that can have multiple writers.
+    """Generates and tracks stream IDs for a stream with multiple writers.
 
     Uses a Postgres sequence to coordinate ID assignment, but positions of other
     writers will only get updated when `advance` is called (by replication).
@@ -475,12 +505,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
         return stream_ids
 
     def get_next(self) -> AsyncContextManager[int]:
-        """
-        Usage:
-            async with stream_id_gen.get_next() as stream_id:
-                # ... persist event ...
-        """
-
         # If we have a list of instances that are allowed to write to this
         # stream, make sure we're in it.
         if self._writers and self._instance_name not in self._writers:
@@ -492,12 +516,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
         return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))
 
     def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
-        """
-        Usage:
-            async with stream_id_gen.get_next_mult(5) as stream_ids:
-                # ... persist events ...
-        """
-
         # If we have a list of instances that are allowed to write to this
         # stream, make sure we're in it.
         if self._writers and self._instance_name not in self._writers:
@@ -597,15 +615,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
             self._add_persisted_position(next_id)
 
     def get_current_token(self) -> int:
-        """Returns the maximum stream id such that all stream ids less than or
-        equal to it have been successfully persisted.
-        """
-
         return self.get_persisted_upto_position()
 
     def get_current_token_for_writer(self, instance_name: str) -> int:
-        """Returns the position of the given writer."""
-
         # If we don't have an entry for the given instance name, we assume it's a
         # new writer.
         #
@@ -631,10 +643,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
             }
 
     def advance(self, instance_name: str, new_id: int) -> None:
-        """Advance the position of the named writer to the given ID, if greater
-        than existing entry.
-        """
-
         new_id *= self._return_factor
 
         with self._lock:
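
A standalone toy illustration of the "current token" semantics documented on
`AbstractStreamIdTracker` above: the current stream ID is the largest ID such that
every ID less than or equal to it has completed, which differs from the ID of the
last completed transaction when completions arrive out of order:

    from typing import Set

    class ToyStreamIdTracker:
        def __init__(self) -> None:
            self._next_id = 0
            self._unfinished: Set[int] = set()

        def begin(self) -> int:
            self._next_id += 1
            self._unfinished.add(self._next_id)
            return self._next_id

        def complete(self, stream_id: int) -> None:
            self._unfinished.discard(stream_id)

        def get_current_token(self) -> int:
            if self._unfinished:
                return min(self._unfinished) - 1
            return self._next_id

    tracker = ToyStreamIdTracker()
    a, b, c = tracker.begin(), tracker.begin(), tracker.begin()  # ids 1, 2, 3
    tracker.complete(c)                       # 3 finished first...
    assert tracker.get_current_token() == 0   # ...but 1 and 2 are still in flight
    tracker.complete(a)
    assert tracker.get_current_token() == 1   # 2 is still outstanding
    tracker.complete(b)
    assert tracker.get_current_token() == 3
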
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 95f23e27b6..f157132210 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -14,9 +14,8 @@
 
 import json
 import logging
-import re
 import typing
-from typing import Any, Callable, Dict, Generator, Optional, Pattern
+from typing import Any, Callable, Dict, Generator, Optional
 
 import attr
 from frozendict import frozendict
@@ -35,9 +34,6 @@ if typing.TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-_WILDCARD_RUN = re.compile(r"([\?\*]+)")
-
-
 def _reject_invalid_json(val: Any) -> None:
     """Do not allow Infinity, -Infinity, or NaN values in JSON."""
     raise ValueError("Invalid JSON value: '%s'" % val)
@@ -185,56 +181,3 @@ def log_failure(
     if not consumeErrors:
         return failure
     return None
-
-
-def glob_to_regex(glob: str, word_boundary: bool = False) -> Pattern:
-    """Converts a glob to a compiled regex object.
-
-    Args:
-        glob: pattern to match
-        word_boundary: If True, the pattern will be allowed to match at word boundaries
-           anywhere in the string. Otherwise, the pattern is anchored at the start and
-           end of the string.
-
-    Returns:
-        compiled regex pattern
-    """
-
-    # Patterns with wildcards must be simplified to avoid performance cliffs
-    # - The glob `?**?**?` is equivalent to the glob `???*`
-    # - The glob `???*` is equivalent to the regex `.{3,}`
-    chunks = []
-    for chunk in _WILDCARD_RUN.split(glob):
-        # No wildcards? re.escape()
-        if not _WILDCARD_RUN.match(chunk):
-            chunks.append(re.escape(chunk))
-            continue
-
-        # Wildcards? Simplify.
-        qmarks = chunk.count("?")
-        if "*" in chunk:
-            chunks.append(".{%d,}" % qmarks)
-        else:
-            chunks.append(".{%d}" % qmarks)
-
-    res = "".join(chunks)
-
-    if word_boundary:
-        res = re_word_boundary(res)
-    else:
-        # \A anchors at start of string, \Z at end of string
-        res = r"\A" + res + r"\Z"
-
-    return re.compile(res, re.IGNORECASE)
-
-
-def re_word_boundary(r: str) -> str:
-    """
-    Adds word boundary characters to the start and end of an
-    expression to require that the match occur as a whole word,
-    but do so respecting the fact that strings starting or ending
-    with non-word characters will change word boundaries.
-    """
-    # we can't use \b as it chokes on unicode. however \W seems to be okay
-    # as shorthand for [^0-9A-Za-z_].
-    return r"(^|\W)%s(\W|$)" % (r,)
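
To illustrate the equivalences claimed in the comments of the removed helper (the glob
`?**?**?` behaves like `???*`, which compiles to the regex `.{3,}`), here is a minimal
reconstruction of just the wildcard-run simplification step:

    import re

    _WILDCARD_RUN = re.compile(r"([\?\*]+)")

    def simplify_glob_chunk(chunk: str) -> str:
        # A chunk with no wildcards is escaped verbatim; a run of wildcards
        # collapses to a counted regex quantifier.
        if not _WILDCARD_RUN.match(chunk):
            return re.escape(chunk)
        qmarks = chunk.count("?")
        return ".{%d,}" % qmarks if "*" in chunk else ".{%d}" % qmarks

    assert simplify_glob_chunk("?**?**?") == ".{3,}"
    assert simplify_glob_chunk("???*") == ".{3,}"
    assert re.fullmatch(simplify_glob_chunk("?**?**?"), "abcd")    # >= 3 chars matches
    assert not re.fullmatch(simplify_glob_chunk("?**?**?"), "ab")  # too short
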
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 561b962e14..20ce294209 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -27,6 +27,7 @@ from typing import (
     Generic,
     Hashable,
     Iterable,
+    Iterator,
     Optional,
     Set,
     TypeVar,
@@ -40,7 +41,6 @@ from typing_extensions import ContextManager
 from twisted.internet import defer
 from twisted.internet.defer import CancelledError
 from twisted.internet.interfaces import IReactorTime
-from twisted.python import failure
 from twisted.python.failure import Failure
 
 from synapse.logging.context import (
@@ -78,7 +78,7 @@ class ObservableDeferred(Generic[_T]):
         object.__setattr__(self, "_result", None)
         object.__setattr__(self, "_observers", [])
 
-        def callback(r):
+        def callback(r: _T) -> _T:
             object.__setattr__(self, "_result", (True, r))
 
             # once we have set _result, no more entries will be added to _observers,
@@ -98,7 +98,7 @@ class ObservableDeferred(Generic[_T]):
                     )
             return r
 
-        def errback(f):
+        def errback(f: Failure) -> Optional[Failure]:
             object.__setattr__(self, "_result", (False, f))
 
             # once we have set _result, no more entries will be added to _observers,
@@ -109,7 +109,7 @@ class ObservableDeferred(Generic[_T]):
             for observer in observers:
                 # This is a little bit of magic to correctly propagate stack
                 # traces when we `await` on one of the observer deferreds.
-                f.value.__failure__ = f
+                f.value.__failure__ = f  # type: ignore[union-attr]
                 try:
                     observer.errback(f)
                 except Exception as e:
@@ -314,7 +314,7 @@ class Linearizer:
         # will release the lock.
 
         @contextmanager
-        def _ctx_manager(_):
+        def _ctx_manager(_: None) -> Iterator[None]:
             try:
                 yield
             finally:
@@ -355,7 +355,7 @@ class Linearizer:
         new_defer = make_deferred_yieldable(defer.Deferred())
         entry.deferreds[new_defer] = 1
 
-        def cb(_r):
+        def cb(_r: None) -> "defer.Deferred[None]":
             logger.debug("Acquired linearizer lock %r for key %r", self.name, key)
             entry.count += 1
 
@@ -371,7 +371,7 @@ class Linearizer:
             # code must be synchronous, so this is the only sensible place.)
             return self._clock.sleep(0)
 
-        def eb(e):
+        def eb(e: Failure) -> Failure:
             logger.info("defer %r got err %r", new_defer, e)
             if isinstance(e, CancelledError):
                 logger.debug(
@@ -435,7 +435,7 @@ class ReadWriteLock:
             await make_deferred_yieldable(curr_writer)
 
         @contextmanager
-        def _ctx_manager():
+        def _ctx_manager() -> Iterator[None]:
             try:
                 yield
             finally:
@@ -464,7 +464,7 @@ class ReadWriteLock:
         await make_deferred_yieldable(defer.gatherResults(to_wait_on))
 
         @contextmanager
-        def _ctx_manager():
+        def _ctx_manager() -> Iterator[None]:
             try:
                 yield
             finally:
@@ -524,7 +524,7 @@ def timeout_deferred(
 
     delayed_call = reactor.callLater(timeout, time_it_out)
 
-    def convert_cancelled(value: failure.Failure):
+    def convert_cancelled(value: Failure) -> Failure:
         # if the original deferred was cancelled, and our timeout has fired, then
         # the reason it was cancelled was due to our timeout. Turn the CancelledError
         # into a TimeoutError.
@@ -534,7 +534,7 @@ def timeout_deferred(
 
     deferred.addErrback(convert_cancelled)
 
-    def cancel_timeout(result):
+    def cancel_timeout(result: _T) -> _T:
         # stop the pending call to cancel the deferred if it's been fired
         if delayed_call.active():
             delayed_call.cancel()
@@ -542,11 +542,11 @@ def timeout_deferred(
 
     deferred.addBoth(cancel_timeout)
 
-    def success_cb(val):
+    def success_cb(val: _T) -> None:
         if not new_d.called:
             new_d.callback(val)
 
-    def failure_cb(val):
+    def failure_cb(val: Failure) -> None:
         if not new_d.called:
             new_d.errback(val)
 
@@ -557,13 +557,13 @@ def timeout_deferred(
 
 # This class can't be generic because it uses slots with attrs.
 # See: https://github.com/python-attrs/attrs/issues/313
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class DoneAwaitable:  # should be: Generic[R]
     """Simple awaitable that returns the provided value."""
 
-    value = attr.ib(type=Any)  # should be: R
+    value: Any  # should be: R
 
-    def __await__(self):
+    def __await__(self) -> Any:
         return self
 
     def __iter__(self) -> "DoneAwaitable":
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index df4d61e4b6..15debd6c46 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -17,7 +17,7 @@ import logging
 import typing
 from enum import Enum, auto
 from sys import intern
-from typing import Callable, Dict, Optional, Sized
+from typing import Any, Callable, Dict, List, Optional, Sized
 
 import attr
 from prometheus_client.core import Gauge
@@ -58,20 +58,20 @@ class EvictionReason(Enum):
     time = auto()
 
 
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class CacheMetric:
 
-    _cache = attr.ib()
-    _cache_type = attr.ib(type=str)
-    _cache_name = attr.ib(type=str)
-    _collect_callback = attr.ib(type=Optional[Callable])
+    _cache: Sized
+    _cache_type: str
+    _cache_name: str
+    _collect_callback: Optional[Callable]
 
-    hits = attr.ib(default=0)
-    misses = attr.ib(default=0)
+    hits: int = 0
+    misses: int = 0
     eviction_size_by_reason: typing.Counter[EvictionReason] = attr.ib(
         factory=collections.Counter
     )
-    memory_usage = attr.ib(default=None)
+    memory_usage: Optional[int] = None
 
     def inc_hits(self) -> None:
         self.hits += 1
@@ -89,13 +89,14 @@ class CacheMetric:
         self.memory_usage += memory
 
     def dec_memory_usage(self, memory: int) -> None:
+        assert self.memory_usage is not None
         self.memory_usage -= memory
 
     def clear_memory_usage(self) -> None:
         if self.memory_usage is not None:
             self.memory_usage = 0
 
-    def describe(self):
+    def describe(self) -> List[str]:
         return []
 
     def collect(self) -> None:
@@ -118,8 +119,9 @@ class CacheMetric:
                         self.eviction_size_by_reason[reason]
                     )
                 cache_total.labels(self._cache_name).set(self.hits + self.misses)
-                if getattr(self._cache, "max_size", None):
-                    cache_max_size.labels(self._cache_name).set(self._cache.max_size)
+                max_size = getattr(self._cache, "max_size", None)
+                if max_size:
+                    cache_max_size.labels(self._cache_name).set(max_size)
 
                 if TRACK_MEMORY_USAGE:
                     # self.memory_usage can be None if nothing has been inserted
@@ -193,7 +195,7 @@ KNOWN_KEYS = {
 }
 
 
-def intern_string(string):
+def intern_string(string: Optional[str]) -> Optional[str]:
     """Takes a (potentially) unicode string and interns it if it's ascii"""
     if string is None:
         return None
@@ -204,7 +206,7 @@ def intern_string(string):
         return string
 
 
-def intern_dict(dictionary):
+def intern_dict(dictionary: Dict[str, Any]) -> Dict[str, Any]:
     """Takes a dictionary and interns well known keys and their values"""
     return {
         KNOWN_KEYS.get(key, key): _intern_known_values(key, value)
@@ -212,7 +214,7 @@ def intern_dict(dictionary):
     }
 
 
-def _intern_known_values(key, value):
+def _intern_known_values(key: str, value: Any) -> Any:
     intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key")
 
     if key in intern_keys:
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index da502aec11..377c9a282a 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -22,6 +22,7 @@ from typing import (
     Iterable,
     MutableMapping,
     Optional,
+    Sized,
     TypeVar,
     Union,
     cast,
@@ -104,7 +105,13 @@ class DeferredCache(Generic[KT, VT]):
             max_size=max_entries,
             cache_name=name,
             cache_type=cache_type,
-            size_callback=(lambda d: len(d) or 1) if iterable else None,
+            size_callback=(
+                (lambda d: len(cast(Sized, d)) or 1)
+                # Argument 1 to "len" has incompatible type "VT"; expected "Sized"
+                # We trust that `VT` is `Sized` when `iterable` is `True`
+                if iterable
+                else None
+            ),
             metrics_collection_callback=metrics_cb,
             apply_cache_factor_from_config=apply_cache_factor_from_config,
             prune_unread_entries=prune_unread_entries,
@@ -289,7 +296,7 @@ class DeferredCache(Generic[KT, VT]):
         callbacks = [callback] if callback else []
         self.cache.set(key, value, callbacks=callbacks)
 
-    def invalidate(self, key) -> None:
+    def invalidate(self, key: KT) -> None:
         """Delete a key, or tree of entries
 
         If the cache is backed by a regular dict, then "key" must be of
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index b9dcca17f1..375cd443f1 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -19,12 +19,15 @@ import logging
 from typing import (
     Any,
     Callable,
+    Dict,
     Generic,
+    Hashable,
     Iterable,
     Mapping,
     Optional,
     Sequence,
     Tuple,
+    Type,
     TypeVar,
     Union,
     cast,
@@ -32,6 +35,7 @@ from typing import (
 from weakref import WeakValueDictionary
 
 from twisted.internet import defer
+from twisted.python.failure import Failure
 
 from synapse.logging.context import make_deferred_yieldable, preserve_fn
 from synapse.util import unwrapFirstError
@@ -60,7 +64,12 @@ class _CachedFunction(Generic[F]):
 
 
 class _CacheDescriptorBase:
-    def __init__(self, orig: Callable[..., Any], num_args, cache_context=False):
+    def __init__(
+        self,
+        orig: Callable[..., Any],
+        num_args: Optional[int],
+        cache_context: bool = False,
+    ):
         self.orig = orig
 
         arg_spec = inspect.getfullargspec(orig)
@@ -172,14 +181,14 @@ class LruCacheDescriptor(_CacheDescriptorBase):
 
     def __init__(
         self,
-        orig,
+        orig: Callable[..., Any],
         max_entries: int = 1000,
         cache_context: bool = False,
     ):
         super().__init__(orig, num_args=None, cache_context=cache_context)
         self.max_entries = max_entries
 
-    def __get__(self, obj, owner):
+    def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
         cache: LruCache[CacheKey, Any] = LruCache(
             cache_name=self.orig.__name__,
             max_size=self.max_entries,
@@ -189,7 +198,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
         sentinel = LruCacheDescriptor._Sentinel.sentinel
 
         @functools.wraps(self.orig)
-        def _wrapped(*args, **kwargs):
+        def _wrapped(*args: Any, **kwargs: Any) -> Any:
             invalidate_callback = kwargs.pop("on_invalidate", None)
             callbacks = (invalidate_callback,) if invalidate_callback else ()
 
@@ -245,19 +254,19 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
             return r1 + r2
 
     Args:
-        num_args (int): number of positional arguments (excluding ``self`` and
+        num_args: number of positional arguments (excluding ``self`` and
             ``cache_context``) to use as cache keys. Defaults to all named
             args of the function.
     """
 
     def __init__(
         self,
-        orig,
-        max_entries=1000,
-        num_args=None,
-        tree=False,
-        cache_context=False,
-        iterable=False,
+        orig: Callable[..., Any],
+        max_entries: int = 1000,
+        num_args: Optional[int] = None,
+        tree: bool = False,
+        cache_context: bool = False,
+        iterable: bool = False,
         prune_unread_entries: bool = True,
     ):
         super().__init__(orig, num_args=num_args, cache_context=cache_context)
@@ -272,7 +281,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         self.iterable = iterable
         self.prune_unread_entries = prune_unread_entries
 
-    def __get__(self, obj, owner):
+    def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
         cache: DeferredCache[CacheKey, Any] = DeferredCache(
             name=self.orig.__name__,
             max_entries=self.max_entries,
@@ -284,7 +293,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         get_cache_key = self.cache_key_builder
 
         @functools.wraps(self.orig)
-        def _wrapped(*args, **kwargs):
+        def _wrapped(*args: Any, **kwargs: Any) -> Any:
             # If we're passed a cache_context then we'll want to call its invalidate()
             # whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
@@ -335,13 +344,19 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
     of results.
     """
 
-    def __init__(self, orig, cached_method_name, list_name, num_args=None):
+    def __init__(
+        self,
+        orig: Callable[..., Any],
+        cached_method_name: str,
+        list_name: str,
+        num_args: Optional[int] = None,
+    ):
         """
         Args:
-            orig (function)
-            cached_method_name (str): The name of the cached method.
-            list_name (str): Name of the argument which is the bulk lookup list
-            num_args (int): number of positional arguments (excluding ``self``,
+            orig: The original function to wrap.
+            cached_method_name: The name of the cached method.
+            list_name: Name of the argument which is the bulk lookup list
+            num_args: number of positional arguments (excluding ``self``,
                 but including list_name) to use as cache keys. Defaults to all
                 named args of the function.
         """
@@ -360,13 +375,15 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                 % (self.list_name, cached_method_name)
             )
 
-    def __get__(self, obj, objtype=None):
+    def __get__(
+        self, obj: Optional[Any], objtype: Optional[Type] = None
+    ) -> Callable[..., Any]:
         cached_method = getattr(obj, self.cached_method_name)
         cache: DeferredCache[CacheKey, Any] = cached_method.cache
         num_args = cached_method.num_args
 
         @functools.wraps(self.orig)
-        def wrapped(*args, **kwargs):
+        def wrapped(*args: Any, **kwargs: Any) -> Any:
             # If we're passed a cache_context then we'll want to call its
             # invalidate() whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
@@ -377,7 +394,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
 
             results = {}
 
-            def update_results_dict(res, arg):
+            def update_results_dict(res: Any, arg: Hashable) -> None:
                 results[arg] = res
 
             # list of deferreds to wait for
@@ -389,13 +406,13 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
             # otherwise a tuple is used.
             if num_args == 1:
 
-                def arg_to_cache_key(arg):
+                def arg_to_cache_key(arg: Hashable) -> Hashable:
                     return arg
 
             else:
                 keylist = list(keyargs)
 
-                def arg_to_cache_key(arg):
+                def arg_to_cache_key(arg: Hashable) -> Hashable:
                     keylist[self.list_pos] = arg
                     return tuple(keylist)
 
@@ -421,7 +438,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                     key = arg_to_cache_key(arg)
                     cache.set(key, deferred, callback=invalidate_callback)
 
-                def complete_all(res):
+                def complete_all(res: Dict[Hashable, Any]) -> None:
                     # the wrapped function has completed. It returns a
                     # dict. We can now resolve the observable deferreds in
                     # the cache and update our own result map.
@@ -430,7 +447,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                         deferreds_map[e].callback(val)
                         results[e] = val
 
-                def errback(f):
+                def errback(f: Failure) -> Failure:
                     # the wrapped function has failed. Invalidate any cache
                     # entries we're supposed to be populating, and fail
                     # their deferreds.
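
The descriptors above are used as decorators on storage-layer methods: num_args limits how many positional arguments form the cache key, cache_context=True passes the function a context whose invalidate() can be registered as an "on_invalidate" callback, and a method wrapped by the list descriptor must return a dict keyed by the entries of its list_name argument. A minimal sketch of that usage, assuming the cached/cachedList decorators exported by this module; UserStore, get_user and get_users are invented names:

from typing import Any, Dict, Iterable

from synapse.util.caches.descriptors import cached, cachedList


class UserStore:
    @cached(num_args=1)
    async def get_user(self, user_id: str) -> Dict[str, Any]:
        # single lookup; the cache key is (user_id,)
        return {}

    @cachedList(cached_method_name="get_user", list_name="user_ids")
    async def get_users(self, user_ids: Iterable[str]) -> Dict[str, Any]:
        # bulk lookup; DeferredCacheListDescriptor resolves a deferred per
        # key from this dict and uses it to fill get_user's cache
        return {user_id: {} for user_id in user_ids}
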
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index c3f72aa06d..67ee4c693b 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -19,6 +19,8 @@ from typing import Any, Generic, Optional, TypeVar, Union, overload
 import attr
 from typing_extensions import Literal
 
+from twisted.internet import defer
+
 from synapse.config import cache as cache_config
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.util import Clock
@@ -81,7 +83,7 @@ class ExpiringCache(Generic[KT, VT]):
             # Don't bother starting the loop if things never expire
             return
 
-        def f():
+        def f() -> "defer.Deferred[None]":
             return run_as_background_process(
                 "prune_cache_%s" % self._cache_name, self._prune_cache
             )
@@ -157,7 +159,7 @@ class ExpiringCache(Generic[KT, VT]):
             self[key] = value
             return value
 
-    def _prune_cache(self) -> None:
+    async def _prune_cache(self) -> None:
         if not self._expiry_ms:
             # zero expiry time means don't expire. This should never get called
             # since we have this check in start too.
@@ -210,7 +212,7 @@ class ExpiringCache(Generic[KT, VT]):
         return False
 
 
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class _CacheEntry:
-    time = attr.ib(type=int)
-    value = attr.ib()
+    time: int
+    value: Any
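
The _CacheEntry change swaps explicit attr.ib(type=...) declarations for auto_attribs=True, where plain class-level annotations are collected into the generated __init__; the two spellings produce equivalent classes. A small stand-alone illustration (the class names are invented):

import attr


@attr.s(slots=True, auto_attribs=True)
class Entry:
    time: int
    value: object


# the pre-patch spelling of the same class:
@attr.s(slots=True)
class EntryOld:
    time = attr.ib(type=int)
    value = attr.ib()


assert Entry(time=1, value="x").time == EntryOld(time=1, value="x").time
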
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index a0a7a9de32..eb96f7e665 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -15,14 +15,15 @@
 import logging
 import threading
 import weakref
+from enum import Enum
 from functools import wraps
 from typing import (
     TYPE_CHECKING,
     Any,
     Callable,
     Collection,
+    Dict,
     Generic,
-    Iterable,
     List,
     Optional,
     Type,
@@ -190,7 +191,7 @@ class _Node(Generic[KT, VT]):
         root: "ListNode[_Node]",
         key: KT,
         value: VT,
-        cache: "weakref.ReferenceType[LruCache]",
+        cache: "weakref.ReferenceType[LruCache[KT, VT]]",
         clock: Clock,
         callbacks: Collection[Callable[[], None]] = (),
         prune_unread_entries: bool = True,
@@ -270,7 +271,10 @@ class _Node(Generic[KT, VT]):
         removed from all lists.
         """
         cache = self._cache()
-        if not cache or not cache.pop(self.key, None):
+        if (
+            cache is None
+            or cache.pop(self.key, _Sentinel.sentinel) is _Sentinel.sentinel
+        ):
             # `cache.pop` should call `drop_from_lists()`, unless this Node had
             # already been removed from the cache.
             self.drop_from_lists()
@@ -290,6 +294,12 @@ class _Node(Generic[KT, VT]):
             self._global_list_node.update_last_access(clock)
 
 
+class _Sentinel(Enum):
+    # defining a sentinel in this way allows mypy to correctly handle the
+    # type of a dictionary lookup.
+    sentinel = object()
+
+
 class LruCache(Generic[KT, VT]):
     """
     Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
@@ -302,7 +312,7 @@ class LruCache(Generic[KT, VT]):
         max_size: int,
         cache_name: Optional[str] = None,
         cache_type: Type[Union[dict, TreeCache]] = dict,
-        size_callback: Optional[Callable] = None,
+        size_callback: Optional[Callable[[VT], int]] = None,
         metrics_collection_callback: Optional[Callable[[], None]] = None,
         apply_cache_factor_from_config: bool = True,
         clock: Optional[Clock] = None,
@@ -339,7 +349,7 @@ class LruCache(Generic[KT, VT]):
         else:
             real_clock = clock
 
-        cache = cache_type()
+        cache: Union[Dict[KT, _Node[KT, VT]], TreeCache] = cache_type()
         self.cache = cache  # Used for introspection.
         self.apply_cache_factor_from_config = apply_cache_factor_from_config
 
@@ -374,7 +384,7 @@ class LruCache(Generic[KT, VT]):
         # creating more each time we create a `_Node`.
         weak_ref_to_self = weakref.ref(self)
 
-        list_root = ListNode[_Node].create_root_node()
+        list_root = ListNode[_Node[KT, VT]].create_root_node()
 
         lock = threading.Lock()
 
@@ -422,7 +432,7 @@ class LruCache(Generic[KT, VT]):
         def add_node(
             key: KT, value: VT, callbacks: Collection[Callable[[], None]] = ()
         ) -> None:
-            node = _Node(
+            node: _Node[KT, VT] = _Node(
                 list_root,
                 key,
                 value,
@@ -439,10 +449,10 @@ class LruCache(Generic[KT, VT]):
             if caches.TRACK_MEMORY_USAGE and metrics:
                 metrics.inc_memory_usage(node.memory)
 
-        def move_node_to_front(node: _Node) -> None:
+        def move_node_to_front(node: _Node[KT, VT]) -> None:
             node.move_to_front(real_clock, list_root)
 
-        def delete_node(node: _Node) -> int:
+        def delete_node(node: _Node[KT, VT]) -> int:
             node.drop_from_lists()
 
             deleted_len = 1
@@ -496,7 +506,7 @@ class LruCache(Generic[KT, VT]):
 
         @synchronized
         def cache_set(
-            key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = ()
+            key: KT, value: VT, callbacks: Collection[Callable[[], None]] = ()
         ) -> None:
             node = cache.get(key, None)
             if node is not None:
@@ -590,8 +600,6 @@ class LruCache(Generic[KT, VT]):
         def cache_contains(key: KT) -> bool:
             return key in cache
 
-        self.sentinel = object()
-
         # make sure that we clear out any excess entries after we get resized.
         self._on_resize = evict
 
@@ -608,18 +616,18 @@ class LruCache(Generic[KT, VT]):
         self.clear = cache_clear
 
     def __getitem__(self, key: KT) -> VT:
-        result = self.get(key, self.sentinel)
-        if result is self.sentinel:
+        result = self.get(key, _Sentinel.sentinel)
+        if result is _Sentinel.sentinel:
             raise KeyError()
         else:
-            return cast(VT, result)
+            return result
 
     def __setitem__(self, key: KT, value: VT) -> None:
         self.set(key, value)
 
     def __delitem__(self, key: KT, value: VT) -> None:
-        result = self.pop(key, self.sentinel)
-        if result is self.sentinel:
+        result = self.pop(key, _Sentinel.sentinel)
+        if result is _Sentinel.sentinel:
             raise KeyError()
 
     def __len__(self) -> int:
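
The _Sentinel enum replaces the old self.sentinel = object() attribute: because an enum member has its own type, mypy can narrow the result of get()/pop() after an "is _Sentinel.sentinel" check, whereas a bare object() default widens the result to object and forced the cast() that this patch removes from __getitem__. A self-contained sketch of the pattern:

from enum import Enum
from typing import Dict, Union


class _Sentinel(Enum):
    # an enum member is a distinct type, so mypy can tell it apart from
    # the dict's value type
    sentinel = object()


cache: Dict[str, int] = {"a": 1}

result: Union[int, _Sentinel] = cache.get("a", _Sentinel.sentinel)
if result is _Sentinel.sentinel:
    raise KeyError("a")
# mypy has narrowed `result` to int here; no cast() is needed
print(result + 1)
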
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 31097d6439..91837655f8 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -18,12 +18,13 @@ from twisted.internet import defer
 
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import UserID
 from synapse.util.async_helpers import maybe_awaitable
 
 logger = logging.getLogger(__name__)
 
 
-def user_left_room(distributor, user, room_id):
+def user_left_room(distributor: "Distributor", user: UserID, room_id: str) -> None:
     distributor.fire("user_left_room", user=user, room_id=room_id)
 
 
@@ -63,7 +64,7 @@ class Distributor:
                 self.pre_registration[name] = []
             self.pre_registration[name].append(observer)
 
-    def fire(self, name: str, *args, **kwargs) -> None:
+    def fire(self, name: str, *args: Any, **kwargs: Any) -> None:
         """Dispatches the given signal to the registered observers.
 
         Runs the observers as a background process. Does not return a deferred.
@@ -95,7 +96,7 @@ class Signal:
         Each observer callable may return a Deferred."""
         self.observers.append(observer)
 
-    def fire(self, *args, **kwargs) -> "defer.Deferred[List[Any]]":
+    def fire(self, *args: Any, **kwargs: Any) -> "defer.Deferred[List[Any]]":
         """Invokes every callable in the observer list, passing in the args and
         kwargs. Exceptions thrown by observers are logged but ignored. It is
         not an error to fire a signal with no observers.
@@ -103,7 +104,7 @@ class Signal:
         Returns a Deferred that will complete when all the observers have
         completed."""
 
-        async def do(observer):
+        async def do(observer: Callable[..., Any]) -> Any:
             try:
                 return await maybe_awaitable(observer(*args, **kwargs))
             except Exception as e:
@@ -120,5 +121,5 @@ class Signal:
             defer.gatherResults(deferreds, consumeErrors=True)
         )
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "<Signal name=%r>" % (self.name,)
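
Signal.fire runs every observer, logging (but swallowing) any exception they raise, and returns a deferred that fires once all observers finish; Distributor.fire wraps the dispatch in a background process and deliberately returns nothing. A rough usage sketch, assuming Distributor's declare/observe registration methods; the log_departure observer is invented:

from synapse.types import UserID
from synapse.util.distributor import Distributor


async def log_departure(user: UserID, room_id: str) -> None:
    # a hypothetical observer; an exception raised here is logged by
    # Signal.fire but does not affect the other observers
    print("%s left %s" % (user.to_string(), room_id))


distributor = Distributor()
distributor.declare("user_left_room")
distributor.observe("user_left_room", log_departure)

# dispatches to every observer as a background process; there is no
# deferred to wait on
distributor.fire(
    "user_left_room",
    user=UserID.from_string("@alice:example.org"),
    room_id="!room:example.org",
)
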
diff --git a/synapse/util/gai_resolver.py b/synapse/util/gai_resolver.py
index a447ce4e55..214eb17fbc 100644
--- a/synapse/util/gai_resolver.py
+++ b/synapse/util/gai_resolver.py
@@ -3,23 +3,52 @@
 # We copy it here as we need to instantiate `GAIResolver` manually, but it is a
 # private class.
 
-
 from socket import (
     AF_INET,
     AF_INET6,
     AF_UNSPEC,
     SOCK_DGRAM,
     SOCK_STREAM,
+    AddressFamily,
+    SocketKind,
     gaierror,
     getaddrinfo,
 )
+from typing import (
+    TYPE_CHECKING,
+    Callable,
+    List,
+    NoReturn,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    Union,
+)
 
 from zope.interface import implementer
 
 from twisted.internet.address import IPv4Address, IPv6Address
-from twisted.internet.interfaces import IHostnameResolver, IHostResolution
+from twisted.internet.interfaces import (
+    IAddress,
+    IHostnameResolver,
+    IHostResolution,
+    IReactorThreads,
+    IResolutionReceiver,
+)
 from twisted.internet.threads import deferToThreadPool
 
+if TYPE_CHECKING:
+    # The types below are copied from
+    # https://github.com/twisted/twisted/blob/release-21.2.0-10091/src/twisted/internet/interfaces.py
+    # so that the type hints can match the interfaces.
+    from twisted.python.runtime import platform
+
+    if platform.supportsThreads():
+        from twisted.python.threadpool import ThreadPool
+    else:
+        ThreadPool = object  # type: ignore[misc, assignment]
+
 
 @implementer(IHostResolution)
 class HostResolution:
@@ -27,13 +56,13 @@ class HostResolution:
     The in-progress resolution of a given hostname.
     """
 
-    def __init__(self, name):
+    def __init__(self, name: str):
         """
         Create a L{HostResolution} with the given name.
         """
         self.name = name
 
-    def cancel(self):
+    def cancel(self) -> NoReturn:
         # IHostResolution.cancel
         raise NotImplementedError()
 
@@ -62,6 +91,17 @@ _socktypeToType = {
 }
 
 
+_GETADDRINFO_RESULT = List[
+    Tuple[
+        AddressFamily,
+        SocketKind,
+        int,
+        str,
+        Union[Tuple[str, int], Tuple[str, int, int, int]],
+    ]
+]
+
+
 @implementer(IHostnameResolver)
 class GAIResolver:
     """
@@ -69,7 +109,12 @@ class GAIResolver:
     L{getaddrinfo} in a thread.
     """
 
-    def __init__(self, reactor, getThreadPool=None, getaddrinfo=getaddrinfo):
+    def __init__(
+        self,
+        reactor: IReactorThreads,
+        getThreadPool: Optional[Callable[[], "ThreadPool"]] = None,
+        getaddrinfo: Callable[[str, int, int, int], _GETADDRINFO_RESULT] = getaddrinfo,
+    ):
         """
         Create a L{GAIResolver}.
         @param reactor: the reactor to schedule result-delivery on
@@ -89,14 +134,16 @@ class GAIResolver:
         )
         self._getaddrinfo = getaddrinfo
 
-    def resolveHostName(
+    # The types on IHostnameResolver are incorrect in Twisted, see
+    # https://twistedmatrix.com/trac/ticket/10276
+    def resolveHostName(  # type: ignore[override]
         self,
-        resolutionReceiver,
-        hostName,
-        portNumber=0,
-        addressTypes=None,
-        transportSemantics="TCP",
-    ):
+        resolutionReceiver: IResolutionReceiver,
+        hostName: str,
+        portNumber: int = 0,
+        addressTypes: Optional[Sequence[Type[IAddress]]] = None,
+        transportSemantics: str = "TCP",
+    ) -> IHostResolution:
         """
         See L{IHostnameResolver.resolveHostName}
         @param resolutionReceiver: see interface
@@ -112,7 +159,7 @@ class GAIResolver:
         ]
         socketType = _transportToSocket[transportSemantics]
 
-        def get():
+        def get() -> _GETADDRINFO_RESULT:
             try:
                 return self._getaddrinfo(
                     hostName, portNumber, addressFamily, socketType
@@ -125,7 +172,7 @@ class GAIResolver:
         resolutionReceiver.resolutionBegan(resolution)
 
         @d.addCallback
-        def deliverResults(result):
+        def deliverResults(result: _GETADDRINFO_RESULT) -> None:
             for family, socktype, _proto, _cannoname, sockaddr in result:
                 addrType = _afToType[family]
                 resolutionReceiver.addressResolved(
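
_GETADDRINFO_RESULT spells out the tuples that socket.getaddrinfo returns and that deliverResults unpacks as (family, socktype, proto, canonname, sockaddr). A quick stand-alone look at that shape:

from socket import AF_UNSPEC, SOCK_STREAM, getaddrinfo

# each entry matches the _GETADDRINFO_RESULT alias above:
# (AddressFamily, SocketKind, proto, canonname, sockaddr)
for family, socktype, proto, canonname, sockaddr in getaddrinfo(
    "localhost", 80, AF_UNSPEC, SOCK_STREAM
):
    print(family.name, sockaddr)
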
diff --git a/synapse/util/linked_list.py b/synapse/util/linked_list.py
index 9f4be757ba..8efbf061aa 100644
--- a/synapse/util/linked_list.py
+++ b/synapse/util/linked_list.py
@@ -84,7 +84,7 @@ class ListNode(Generic[P]):
         # immediately rather than at the next GC.
         self.cache_entry = None
 
-    def move_after(self, node: "ListNode") -> None:
+    def move_after(self, node: "ListNode[P]") -> None:
         """Move this node from its current location in the list to after the
         given node.
         """
@@ -122,7 +122,7 @@ class ListNode(Generic[P]):
         self.prev_node = None
         self.next_node = None
 
-    def _refs_insert_after(self, node: "ListNode") -> None:
+    def _refs_insert_after(self, node: "ListNode[P]") -> None:
         """Internal method to insert the node after the given node."""
 
         # This method should only be called when we're not already in the list.
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 1e784b3f1f..98ee49af6e 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -56,14 +56,22 @@ block_db_sched_duration = Counter(
     "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"]
 )
 
+
+# This is dynamically created in InFlightGauge.__init__.
+class _InFlightMetric(Protocol):
+    real_time_max: float
+    real_time_sum: float
+
+
 # Tracks the number of blocks currently active
-in_flight = InFlightGauge(
+in_flight: InFlightGauge[_InFlightMetric] = InFlightGauge(
     "synapse_util_metrics_block_in_flight",
     "",
     labels=["block_name"],
     sub_metrics=["real_time_max", "real_time_sum"],
 )
 
+
 T = TypeVar("T", bound=Callable[..., Any])
 
 
@@ -180,7 +188,7 @@ class Measure:
         """
         return self._logging_context.get_resource_usage()
 
-    def _update_in_flight(self, metrics) -> None:
+    def _update_in_flight(self, metrics: _InFlightMetric) -> None:
         """Gets called when processing in flight metrics"""
         assert self.start is not None
         duration = self.clock.time() - self.start
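
_InFlightMetric describes the attribute bag that InFlightGauge builds per key and hands to callbacks such as _update_in_flight. A rough sketch of how the gauge above is driven, mirroring what Measure does; the block name and duration are invented, and the register/unregister calls are assumed from InFlightGauge's API:

def _update(metrics: _InFlightMetric) -> None:
    # called by the gauge when the in-flight metrics are collected
    duration = 0.25  # illustrative value
    metrics.real_time_sum += duration
    metrics.real_time_max = max(metrics.real_time_max, duration)


key = ("my_block",)  # one entry per label ("block_name")
in_flight.register(key, _update)
try:
    pass  # ... the measured block runs here ...
finally:
    in_flight.unregister(key, _update)
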
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index f029432191..ea1032b4fc 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -19,6 +19,8 @@ import string
 from collections.abc import Iterable
 from typing import Optional, Tuple
 
+from netaddr import valid_ipv6
+
 from synapse.api.errors import Codes, SynapseError
 
 _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
@@ -97,7 +99,10 @@ def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]:
         raise ValueError("Invalid server name '%s'" % server_name)
 
 
-VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z")
+# An approximation of the domain name syntax in RFC 1035, section 2.3.1.
+# NB: "\Z" is not equivalent to "$".
+#     The latter will match the position before a "\n" at the end of a string.
+VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z-]+(?:\\.[0-9a-zA-Z-]+)*\\Z")
 
 
 def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]]:
@@ -122,13 +127,15 @@ def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]
     if host[0] == "[":
         if host[-1] != "]":
             raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))
-        return host, port
 
-    # otherwise it should only be alphanumerics.
-    if not VALID_HOST_REGEX.match(host):
-        raise ValueError(
-            "Server name '%s' contains invalid characters" % (server_name,)
-        )
+        # valid_ipv6 raises when given an empty string
+        ipv6_address = host[1:-1]
+        if not ipv6_address or not valid_ipv6(ipv6_address):
+            raise ValueError(
+                "Server name '%s' is not a valid IPv6 address" % (server_name,)
+            )
+    elif not VALID_HOST_REGEX.match(host):
+        raise ValueError("Server name '%s' has an invalid format" % (server_name,))
 
     return host, port
 
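
The "\Z" note matters because "$" also matches just before a trailing newline, so a host name like "matrix.org\n" would slip through validation. A small stand-alone check of the difference (the patterns are copied from above; the inputs are invented):

import re

with_dollar = re.compile("\\A[0-9a-zA-Z-]+(?:\\.[0-9a-zA-Z-]+)*$")
with_z = re.compile("\\A[0-9a-zA-Z-]+(?:\\.[0-9a-zA-Z-]+)*\\Z")

assert with_dollar.match("matrix.org\n")     # "$" tolerates the trailing "\n"
assert with_z.match("matrix.org\n") is None  # "\Z" does not
assert with_z.match("matrix.org")            # plain host names still pass
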
diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py
index 899ee0adc8..c144ff62c1 100644
--- a/synapse/util/versionstring.py
+++ b/synapse/util/versionstring.py
@@ -1,4 +1,5 @@
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -29,10 +30,11 @@ def get_version_string(module: ModuleType) -> str:
     If called on a module not in a git checkout will return `__version__`.
 
     Args:
-        module (module)
+        module: The module to check the version of. Must declare a __version__
+            attribute.
 
     Returns:
-        str
+        The module version (as a string).
     """
 
     cached_version = version_cache.get(module)
@@ -44,71 +46,37 @@ def get_version_string(module: ModuleType) -> str:
     version_string = module.__version__  # type: ignore[attr-defined]
 
     try:
-        null = open(os.devnull, "w")
         cwd = os.path.dirname(os.path.abspath(module.__file__))
 
-        try:
-            git_branch = (
-                subprocess.check_output(
-                    ["git", "rev-parse", "--abbrev-ref", "HEAD"], stderr=null, cwd=cwd
+        def _run_git_command(prefix: str, *params: str) -> str:
+            try:
+                result = (
+                    subprocess.check_output(
+                        ["git", *params], stderr=subprocess.DEVNULL, cwd=cwd
+                    )
+                    .strip()
+                    .decode("ascii")
                 )
-                .strip()
-                .decode("ascii")
-            )
-            git_branch = "b=" + git_branch
-        except (subprocess.CalledProcessError, FileNotFoundError):
-            # FileNotFoundError can arise when git is not installed
-            git_branch = ""
-
-        try:
-            git_tag = (
-                subprocess.check_output(
-                    ["git", "describe", "--exact-match"], stderr=null, cwd=cwd
-                )
-                .strip()
-                .decode("ascii")
-            )
-            git_tag = "t=" + git_tag
-        except (subprocess.CalledProcessError, FileNotFoundError):
-            git_tag = ""
-
-        try:
-            git_commit = (
-                subprocess.check_output(
-                    ["git", "rev-parse", "--short", "HEAD"], stderr=null, cwd=cwd
-                )
-                .strip()
-                .decode("ascii")
-            )
-        except (subprocess.CalledProcessError, FileNotFoundError):
-            git_commit = ""
-
-        try:
-            dirty_string = "-this_is_a_dirty_checkout"
-            is_dirty = (
-                subprocess.check_output(
-                    ["git", "describe", "--dirty=" + dirty_string], stderr=null, cwd=cwd
-                )
-                .strip()
-                .decode("ascii")
-                .endswith(dirty_string)
-            )
+                return prefix + result
+            except (subprocess.CalledProcessError, FileNotFoundError):
+                return ""
 
-            git_dirty = "dirty" if is_dirty else ""
-        except (subprocess.CalledProcessError, FileNotFoundError):
-            git_dirty = ""
+        git_branch = _run_git_command("b=", "rev-parse", "--abbrev-ref", "HEAD")
+        git_tag = _run_git_command("t=", "describe", "--exact-match")
+        git_commit = _run_git_command("", "rev-parse", "--short", "HEAD")
+
+        dirty_string = "-this_is_a_dirty_checkout"
+        is_dirty = _run_git_command("", "describe", "--dirty=" + dirty_string).endswith(
+            dirty_string
+        )
+        git_dirty = "dirty" if is_dirty else ""
 
         if git_branch or git_tag or git_commit or git_dirty:
             git_version = ",".join(
                 s for s in (git_branch, git_tag, git_commit, git_dirty) if s
             )
 
-            version_string = "%s (%s)" % (
-                # If the __version__ attribute doesn't exist, we'll have failed
-                # loudly above.
-                module.__version__,  # type: ignore[attr-defined]
-                git_version,
-            )
+            version_string = f"{version_string} ({git_version})"
     except Exception as e:
         logger.info("Failed to check for git repository: %s", e)