summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rwxr-xr-xsynapse/_scripts/generate_config.py2
-rw-r--r--synapse/_scripts/register_new_matrix_user.py102
-rwxr-xr-xsynapse/_scripts/synapse_port_db.py62
-rw-r--r--synapse/api/auth.py416
-rw-r--r--synapse/api/constants.py18
-rw-r--r--synapse/api/errors.py68
-rw-r--r--synapse/api/ratelimiting.py105
-rw-r--r--synapse/api/room_versions.py55
-rw-r--r--synapse/app/_base.py49
-rw-r--r--synapse/app/admin_cmd.py41
-rw-r--r--synapse/app/appservice.py7
-rw-r--r--synapse/app/client_reader.py7
-rw-r--r--synapse/app/complement_fork_starter.py190
-rw-r--r--synapse/app/event_creator.py7
-rw-r--r--synapse/app/federation_reader.py7
-rw-r--r--synapse/app/federation_sender.py7
-rw-r--r--synapse/app/frontend_proxy.py7
-rw-r--r--synapse/app/generic_worker.py49
-rw-r--r--synapse/app/homeserver.py14
-rw-r--r--synapse/app/media_repository.py7
-rw-r--r--synapse/app/pusher.py7
-rw-r--r--synapse/app/synchrotron.py7
-rw-r--r--synapse/app/user_dir.py7
-rw-r--r--synapse/appservice/api.py14
-rw-r--r--synapse/appservice/scheduler.py9
-rw-r--r--synapse/config/_base.py83
-rw-r--r--synapse/config/account_validity.py2
-rw-r--r--synapse/config/api.py48
-rw-r--r--synapse/config/appservice.py14
-rw-r--r--synapse/config/auth.py75
-rw-r--r--synapse/config/background_updates.py34
-rw-r--r--synapse/config/cache.py102
-rw-r--r--synapse/config/captcha.py27
-rw-r--r--synapse/config/cas.py31
-rw-r--r--synapse/config/consent.py55
-rw-r--r--synapse/config/database.py50
-rw-r--r--synapse/config/emailconfig.py208
-rw-r--r--synapse/config/experimental.py15
-rw-r--r--synapse/config/federation.py39
-rw-r--r--synapse/config/groups.py (renamed from synapse/replication/slave/storage/receipts.py)17
-rw-r--r--synapse/config/jwt.py81
-rw-r--r--synapse/config/key.py99
-rw-r--r--synapse/config/logger.py5
-rw-r--r--synapse/config/metrics.py82
-rw-r--r--synapse/config/modules.py17
-rw-r--r--synapse/config/oembed.py23
-rw-r--r--synapse/config/oidc.py207
-rw-r--r--synapse/config/push.py33
-rw-r--r--synapse/config/ratelimiting.py172
-rw-r--r--synapse/config/redis.py21
-rw-r--r--synapse/config/registration.py322
-rw-r--r--synapse/config/repository.py211
-rw-r--r--synapse/config/retention.py72
-rw-r--r--synapse/config/room.py56
-rw-r--r--synapse/config/room_directory.py66
-rw-r--r--synapse/config/saml2.py192
-rw-r--r--synapse/config/server.py514
-rw-r--r--synapse/config/server_notices.py24
-rw-r--r--synapse/config/sso.py42
-rw-r--r--synapse/config/stats.py13
-rw-r--r--synapse/config/tls.py90
-rw-r--r--synapse/config/tracer.py59
-rw-r--r--synapse/config/user_directory.py39
-rw-r--r--synapse/config/voip.py31
-rw-r--r--synapse/config/workers.py49
-rw-r--r--synapse/crypto/event_signing.py2
-rw-r--r--synapse/event_auth.py299
-rw-r--r--synapse/events/builder.py10
-rw-r--r--synapse/events/snapshot.py7
-rw-r--r--synapse/events/spamcheck.py372
-rw-r--r--synapse/events/third_party_rules.py9
-rw-r--r--synapse/events/utils.py2
-rw-r--r--synapse/events/validator.py4
-rw-r--r--synapse/federation/federation_base.py22
-rw-r--r--synapse/federation/federation_client.py185
-rw-r--r--synapse/federation/federation_server.py75
-rw-r--r--synapse/federation/sender/__init__.py23
-rw-r--r--synapse/federation/transport/client.py2
-rw-r--r--synapse/federation/transport/server/_base.py7
-rw-r--r--synapse/handlers/appservice.py11
-rw-r--r--synapse/handlers/auth.py133
-rw-r--r--synapse/handlers/device.py92
-rw-r--r--synapse/handlers/directory.py55
-rw-r--r--synapse/handlers/e2e_keys.py32
-rw-r--r--synapse/handlers/e2e_room_keys.py4
-rw-r--r--synapse/handlers/event_auth.py18
-rw-r--r--synapse/handlers/events.py9
-rw-r--r--synapse/handlers/federation.py260
-rw-r--r--synapse/handlers/federation_event.py796
-rw-r--r--synapse/handlers/identity.py232
-rw-r--r--synapse/handlers/initial_sync.py17
-rw-r--r--synapse/handlers/message.py215
-rw-r--r--synapse/handlers/oidc.py131
-rw-r--r--synapse/handlers/pagination.py10
-rw-r--r--synapse/handlers/presence.py115
-rw-r--r--synapse/handlers/profile.py19
-rw-r--r--synapse/handlers/receipts.py7
-rw-r--r--synapse/handlers/register.py18
-rw-r--r--synapse/handlers/relations.py7
-rw-r--r--synapse/handlers/room.py205
-rw-r--r--synapse/handlers/room_list.py23
-rw-r--r--synapse/handlers/room_member.py334
-rw-r--r--synapse/handlers/room_summary.py5
-rw-r--r--synapse/handlers/send_email.py36
-rw-r--r--synapse/handlers/stats.py3
-rw-r--r--synapse/handlers/sync.py423
-rw-r--r--synapse/handlers/typing.py30
-rw-r--r--synapse/handlers/ui_auth/checkers.py21
-rw-r--r--synapse/http/matrixfederationclient.py16
-rw-r--r--synapse/http/server.py97
-rw-r--r--synapse/http/servlet.py25
-rw-r--r--synapse/http/site.py2
-rw-r--r--synapse/logging/opentracing.py305
-rw-r--r--synapse/logging/scopecontextmanager.py35
-rw-r--r--synapse/metrics/__init__.py4
-rw-r--r--synapse/metrics/_legacy_exposition.py (renamed from synapse/metrics/_exposition.py)34
-rw-r--r--synapse/metrics/background_process_metrics.py2
-rw-r--r--synapse/module_api/__init__.py86
-rw-r--r--synapse/notifier.py18
-rw-r--r--synapse/push/baserules.py537
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py447
-rw-r--r--synapse/push/clientformat.py68
-rw-r--r--synapse/push/mailer.py7
-rw-r--r--synapse/push/push_rule_evaluator.py27
-rw-r--r--synapse/push/push_tools.py33
-rw-r--r--synapse/push/pusherpool.py27
-rw-r--r--synapse/replication/http/__init__.py2
-rw-r--r--synapse/replication/http/_base.py11
-rw-r--r--synapse/replication/http/federation.py3
-rw-r--r--synapse/replication/http/send_event.py3
-rw-r--r--synapse/replication/http/state.py75
-rw-r--r--synapse/replication/slave/storage/_base.py58
-rw-r--r--synapse/replication/slave/storage/account_data.py22
-rw-r--r--synapse/replication/slave/storage/appservice.py25
-rw-r--r--synapse/replication/slave/storage/deviceinbox.py20
-rw-r--r--synapse/replication/slave/storage/devices.py4
-rw-r--r--synapse/replication/slave/storage/directory.py21
-rw-r--r--synapse/replication/slave/storage/events.py3
-rw-r--r--synapse/replication/slave/storage/filtering.py5
-rw-r--r--synapse/replication/slave/storage/profile.py20
-rw-r--r--synapse/replication/slave/storage/push_rule.py1
-rw-r--r--synapse/replication/slave/storage/pushers.py3
-rw-r--r--synapse/replication/slave/storage/registration.py21
-rw-r--r--synapse/replication/tcp/client.py17
-rw-r--r--synapse/replication/tcp/streams/events.py1
-rw-r--r--synapse/res/templates/account_previously_renewed.html13
-rw-r--r--synapse/res/templates/account_renewed.html13
-rw-r--r--synapse/res/templates/add_threepid.html11
-rw-r--r--synapse/res/templates/add_threepid_failure.html15
-rw-r--r--synapse/res/templates/add_threepid_success.html14
-rw-r--r--synapse/res/templates/auth_success.html4
-rw-r--r--synapse/res/templates/invalid_token.html13
-rw-r--r--synapse/res/templates/notice_expiry.html2
-rw-r--r--synapse/res/templates/notif_mail.html2
-rw-r--r--synapse/res/templates/password_reset.html7
-rw-r--r--synapse/res/templates/password_reset_confirmation.html8
-rw-r--r--synapse/res/templates/password_reset_failure.html8
-rw-r--r--synapse/res/templates/password_reset_success.html7
-rw-r--r--synapse/res/templates/recaptcha.html4
-rw-r--r--synapse/res/templates/registration.html7
-rw-r--r--synapse/res/templates/registration_failure.html7
-rw-r--r--synapse/res/templates/registration_success.html8
-rw-r--r--synapse/res/templates/registration_token.html6
-rw-r--r--synapse/res/templates/sso_account_deactivated.html4
-rw-r--r--synapse/res/templates/sso_auth_account_details.html5
-rw-r--r--synapse/res/templates/sso_auth_bad_user.html3
-rw-r--r--synapse/res/templates/sso_auth_confirm.html3
-rw-r--r--synapse/res/templates/sso_auth_success.html3
-rw-r--r--synapse/res/templates/sso_error.html3
-rw-r--r--synapse/res/templates/sso_login_idp_picker.html2
-rw-r--r--synapse/res/templates/sso_new_user_consent.html3
-rw-r--r--synapse/res/templates/sso_redirect_confirm.html3
-rw-r--r--synapse/res/templates/terms.html4
-rw-r--r--synapse/rest/admin/_base.py10
-rw-r--r--synapse/rest/admin/media.py6
-rw-r--r--synapse/rest/admin/rooms.py13
-rw-r--r--synapse/rest/admin/users.py16
-rw-r--r--synapse/rest/client/account.py231
-rw-r--r--synapse/rest/client/account_data.py19
-rw-r--r--synapse/rest/client/devices.py27
-rw-r--r--synapse/rest/client/directory.py52
-rw-r--r--synapse/rest/client/keys.py7
-rw-r--r--synapse/rest/client/login.py56
-rw-r--r--synapse/rest/client/models.py69
-rw-r--r--synapse/rest/client/notifications.py6
-rw-r--r--synapse/rest/client/profile.py24
-rw-r--r--synapse/rest/client/pusher.py50
-rw-r--r--synapse/rest/client/read_marker.py58
-rw-r--r--synapse/rest/client/receipts.py20
-rw-r--r--synapse/rest/client/register.py66
-rw-r--r--synapse/rest/client/relations.py13
-rw-r--r--synapse/rest/client/room.py121
-rw-r--r--synapse/rest/client/room_keys.py13
-rw-r--r--synapse/rest/client/sendtodevice.py3
-rw-r--r--synapse/rest/client/sync.py12
-rw-r--r--synapse/rest/client/versions.py4
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py41
-rw-r--r--synapse/rest/media/v1/_base.py40
-rw-r--r--synapse/rest/media/v1/download_resource.py7
-rw-r--r--synapse/rest/media/v1/media_repository.py1
-rw-r--r--synapse/rest/media/v1/media_storage.py9
-rw-r--r--synapse/rest/media/v1/preview_html.py174
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py66
-rw-r--r--synapse/rest/media/v1/thumbnail_resource.py47
-rw-r--r--synapse/rest/models.py23
-rw-r--r--synapse/rest/synapse/client/__init__.py3
-rw-r--r--synapse/rest/synapse/client/password_reset.py8
-rw-r--r--synapse/rest/synapse/client/unsubscribe.py64
-rw-r--r--synapse/rest/well_known.py9
-rw-r--r--synapse/server.py16
-rw-r--r--synapse/server_notices/resource_limits_server_notices.py4
-rw-r--r--synapse/server_notices/server_notices_manager.py12
-rw-r--r--synapse/state/__init__.py324
-rw-r--r--synapse/state/v1.py8
-rw-r--r--synapse/state/v2.py141
-rw-r--r--synapse/static/client/login/index.html3
-rw-r--r--synapse/static/client/register/index.html3
-rw-r--r--synapse/storage/_base.py22
-rw-r--r--synapse/storage/background_updates.py19
-rw-r--r--synapse/storage/controllers/__init__.py4
-rw-r--r--synapse/storage/controllers/persist_events.py195
-rw-r--r--synapse/storage/controllers/state.py65
-rw-r--r--synapse/storage/database.py66
-rw-r--r--synapse/storage/databases/main/__init__.py38
-rw-r--r--synapse/storage/databases/main/account_data.py3
-rw-r--r--synapse/storage/databases/main/appservice.py58
-rw-r--r--synapse/storage/databases/main/cache.py33
-rw-r--r--synapse/storage/databases/main/censor_events.py2
-rw-r--r--synapse/storage/databases/main/deviceinbox.py13
-rw-r--r--synapse/storage/databases/main/devices.py125
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py53
-rw-r--r--synapse/storage/databases/main/event_federation.py9
-rw-r--r--synapse/storage/databases/main/event_push_actions.py947
-rw-r--r--synapse/storage/databases/main/events.py151
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py208
-rw-r--r--synapse/storage/databases/main/events_worker.py306
-rw-r--r--synapse/storage/databases/main/group_server.py34
-rw-r--r--synapse/storage/databases/main/media_repository.py10
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py1
-rw-r--r--synapse/storage/databases/main/purge_events.py37
-rw-r--r--synapse/storage/databases/main/push_rule.py130
-rw-r--r--synapse/storage/databases/main/receipts.py92
-rw-r--r--synapse/storage/databases/main/registration.py13
-rw-r--r--synapse/storage/databases/main/relations.py6
-rw-r--r--synapse/storage/databases/main/room.py175
-rw-r--r--synapse/storage/databases/main/roommember.py502
-rw-r--r--synapse/storage/databases/main/search.py10
-rw-r--r--synapse/storage/databases/main/state.py14
-rw-r--r--synapse/storage/databases/main/stats.py16
-rw-r--r--synapse/storage/databases/main/stream.py54
-rw-r--r--synapse/storage/databases/state/bg_updates.py3
-rw-r--r--synapse/storage/databases/state/store.py165
-rw-r--r--synapse/storage/engines/__init__.py38
-rw-r--r--synapse/storage/engines/postgres.py24
-rw-r--r--synapse/storage/prepare_database.py3
-rw-r--r--synapse/storage/schema/__init__.py14
-rw-r--r--synapse/storage/schema/main/delta/40/event_push_summary.sql7
-rw-r--r--synapse/storage/schema/main/delta/71/01rebuild_event_edges.sql.postgres43
-rw-r--r--synapse/storage/schema/main/delta/71/01rebuild_event_edges.sql.sqlite47
-rw-r--r--synapse/storage/schema/main/delta/71/01remove_noop_background_updates.sql61
-rw-r--r--synapse/storage/schema/main/delta/71/02event_push_summary_unique.sql18
-rw-r--r--synapse/storage/schema/main/delta/72/01add_room_type_to_state_stats.sql19
-rw-r--r--synapse/storage/schema/main/delta/72/01event_push_summary_receipt.sql35
-rw-r--r--synapse/storage/schema/main/delta/72/02event_push_actions_index.sql19
-rw-r--r--synapse/storage/schema/main/delta/72/03bg_populate_events_columns.py47
-rw-r--r--synapse/storage/schema/main/delta/72/03drop_event_reference_hashes.sql17
-rw-r--r--synapse/storage/schema/main/delta/72/03remove_groups.sql31
-rw-r--r--synapse/storage/schema/main/delta/72/04drop_column_application_services_state_last_txn.sql.postgres17
-rw-r--r--synapse/storage/schema/main/delta/72/04drop_column_application_services_state_last_txn.sql.sqlite40
-rw-r--r--synapse/storage/schema/main/delta/72/05remove_unstable_private_read_receipts.sql19
-rw-r--r--synapse/storage/state.py8
-rw-r--r--synapse/storage/util/id_generators.py13
-rw-r--r--synapse/storage/util/partial_state_events_tracker.py4
-rw-r--r--synapse/streams/events.py2
-rw-r--r--synapse/types.py3
-rw-r--r--synapse/util/__init__.py6
-rw-r--r--synapse/util/caches/__init__.py16
-rw-r--r--synapse/util/caches/deferred_cache.py346
-rw-r--r--synapse/util/caches/descriptors.py115
-rw-r--r--synapse/util/caches/dictionary_cache.py218
-rw-r--r--synapse/util/caches/lrucache.py147
-rw-r--r--synapse/util/caches/treecache.py41
-rw-r--r--synapse/util/cancellation.py56
-rw-r--r--synapse/util/macaroons.py308
-rw-r--r--synapse/util/ratelimitutils.py214
-rw-r--r--synapse/visibility.py530
286 files changed, 10820 insertions, 8315 deletions
diff --git a/synapse/_scripts/generate_config.py b/synapse/_scripts/generate_config.py
index 08eb8ef114..06c11c60da 100755
--- a/synapse/_scripts/generate_config.py
+++ b/synapse/_scripts/generate_config.py
@@ -33,7 +33,7 @@ def main() -> None:
     parser.add_argument(
         "--report-stats",
         action="store",
-        help="Whether the generated config reports anonymized usage statistics",
+        help="Whether the generated config reports homeserver usage statistics",
         choices=["yes", "no"],
     )
 
diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py
index 092601f530..0c4504d5d8 100644
--- a/synapse/_scripts/register_new_matrix_user.py
+++ b/synapse/_scripts/register_new_matrix_user.py
@@ -1,6 +1,6 @@
 # Copyright 2015, 2016 OpenMarket Ltd
 # Copyright 2018 New Vector
-# Copyright 2021 The Matrix.org Foundation C.I.C.
+# Copyright 2021-22 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,11 +20,22 @@ import hashlib
 import hmac
 import logging
 import sys
-from typing import Callable, Optional
+from typing import Any, Callable, Dict, Optional
 
 import requests
 import yaml
 
+_CONFLICTING_SHARED_SECRET_OPTS_ERROR = """\
+Conflicting options 'registration_shared_secret' and 'registration_shared_secret_path'
+are both defined in config file.
+"""
+
+_NO_SHARED_SECRET_OPTS_ERROR = """\
+No 'registration_shared_secret' or 'registration_shared_secret_path' defined in config.
+"""
+
+_DEFAULT_SERVER_URL = "http://localhost:8008"
+
 
 def request_registration(
     user: str,
@@ -203,31 +214,104 @@ def main() -> None:
 
     parser.add_argument(
         "server_url",
-        default="https://localhost:8448",
         nargs="?",
-        help="URL to use to talk to the homeserver. Defaults to "
-        " 'https://localhost:8448'.",
+        help="URL to use to talk to the homeserver. By default, tries to find a "
+        "suitable URL from the configuration file. Otherwise, defaults to "
+        f"'{_DEFAULT_SERVER_URL}'.",
     )
 
     args = parser.parse_args()
 
     if "config" in args and args.config:
         config = yaml.safe_load(args.config)
-        secret = config.get("registration_shared_secret", None)
+
+    if args.shared_secret:
+        secret = args.shared_secret
+    else:
+        # argparse should check that we have either config or shared secret
+        assert config
+
+        secret = config.get("registration_shared_secret")
+        secret_file = config.get("registration_shared_secret_path")
+        if secret_file:
+            if secret:
+                print(_CONFLICTING_SHARED_SECRET_OPTS_ERROR, file=sys.stderr)
+                sys.exit(1)
+            secret = _read_file(secret_file, "registration_shared_secret_path").strip()
         if not secret:
-            print("No 'registration_shared_secret' defined in config.")
+            print(_NO_SHARED_SECRET_OPTS_ERROR, file=sys.stderr)
             sys.exit(1)
+
+    if args.server_url:
+        server_url = args.server_url
+    elif config:
+        server_url = _find_client_listener(config)
+        if not server_url:
+            server_url = _DEFAULT_SERVER_URL
+            print(
+                "Unable to find a suitable HTTP listener in the configuration file. "
+                f"Trying {server_url} as a last resort.",
+                file=sys.stderr,
+            )
     else:
-        secret = args.shared_secret
+        server_url = _DEFAULT_SERVER_URL
+        print(
+            f"No server url or configuration file given. Defaulting to {server_url}.",
+            file=sys.stderr,
+        )
 
     admin = None
     if args.admin or args.no_admin:
         admin = args.admin
 
     register_new_user(
-        args.user, args.password, args.server_url, secret, admin, args.user_type
+        args.user, args.password, server_url, secret, admin, args.user_type
     )
 
 
+def _read_file(file_path: Any, config_path: str) -> str:
+    """Check the given file exists, and read it into a string
+
+    If it does not, exit with an error indicating the problem
+
+    Args:
+        file_path: the file to be read
+        config_path: where in the configuration file_path came from, so that a useful
+           error can be emitted if it does not exist.
+    Returns:
+        content of the file.
+    """
+    if not isinstance(file_path, str):
+        print(f"{config_path} setting is not a string", file=sys.stderr)
+        sys.exit(1)
+
+    try:
+        with open(file_path) as file_stream:
+            return file_stream.read()
+    except OSError as e:
+        print(f"Error accessing file {file_path}: {e}", file=sys.stderr)
+        sys.exit(1)
+
+
+def _find_client_listener(config: Dict[str, Any]) -> Optional[str]:
+    # try to find a listener in the config. Returns a host:port pair
+    for listener in config.get("listeners", []):
+        if listener.get("type") != "http" or listener.get("tls", False):
+            continue
+
+        if not any(
+            name == "client"
+            for resource in listener.get("resources", [])
+            for name in resource.get("names", [])
+        ):
+            continue
+
+        # TODO: consider bind_addresses
+        return f"http://localhost:{listener['port']}"
+
+    # no suitable listeners?
+    return None
+
+
 if __name__ == "__main__":
     main()
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index c753dfa7cb..543bba27c2 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -58,10 +58,10 @@ from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateSt
 from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore
 from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore
 from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyBackgroundStore
+from synapse.storage.databases.main.event_push_actions import EventPushActionsStore
 from synapse.storage.databases.main.events_bg_updates import (
     EventsBackgroundUpdatesStore,
 )
-from synapse.storage.databases.main.group_server import GroupServerStore
 from synapse.storage.databases.main.media_repository import (
     MediaRepositoryBackgroundUpdateStore,
 )
@@ -166,22 +166,6 @@ IGNORED_TABLES = {
     "ui_auth_sessions",
     "ui_auth_sessions_credentials",
     "ui_auth_sessions_ips",
-    # Groups/communities is no longer supported.
-    "group_attestations_remote",
-    "group_attestations_renewals",
-    "group_invites",
-    "group_roles",
-    "group_room_categories",
-    "group_rooms",
-    "group_summary_roles",
-    "group_summary_room_categories",
-    "group_summary_rooms",
-    "group_summary_users",
-    "group_users",
-    "groups",
-    "local_group_membership",
-    "local_group_updates",
-    "remote_profile_cache",
 }
 
 
@@ -200,6 +184,7 @@ R = TypeVar("R")
 
 
 class Store(
+    EventPushActionsStore,
     ClientIpBackgroundUpdateStore,
     DeviceInboxBackgroundUpdateStore,
     DeviceBackgroundUpdateStore,
@@ -218,7 +203,6 @@ class Store(
     PushRuleStore,
     PusherWorkerStore,
     PresenceBackgroundUpdateStore,
-    GroupServerStore,
 ):
     def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
         return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
@@ -268,6 +252,9 @@ class MockHomeserver:
     def get_instance_name(self) -> str:
         return "master"
 
+    def should_send_federation(self) -> bool:
+        return False
+
 
 class Porter:
     def __init__(
@@ -415,12 +402,15 @@ class Porter:
             self.progress.update(table, table_size)  # Mark table as done
             return
 
+        # We sweep over rowids in two directions: one forwards (rowids 1, 2, 3, ...)
+        # and another backwards (rowids 0, -1, -2, ...).
         forward_select = (
             "SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?" % (table,)
         )
 
         backward_select = (
-            "SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?" % (table,)
+            "SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid DESC LIMIT ?"
+            % (table,)
         )
 
         do_forward = [True]
@@ -618,6 +608,25 @@ class Porter:
                 self.postgres_store.db_pool.updates.has_completed_background_updates()
             )
 
+    @staticmethod
+    def _is_sqlite_autovacuum_enabled(txn: LoggingTransaction) -> bool:
+        """
+        Returns true if auto_vacuum is enabled in SQLite.
+        https://www.sqlite.org/pragma.html#pragma_auto_vacuum
+
+        Vacuuming changes the rowids on rows in the database.
+        Auto-vacuuming is therefore dangerous when used in conjunction with this script.
+
+        Note that the auto_vacuum setting can't be changed without performing
+        a VACUUM after trying to change the pragma.
+        """
+        txn.execute("PRAGMA auto_vacuum")
+        row = txn.fetchone()
+        assert row is not None, "`PRAGMA auto_vacuum` did not give a row."
+        (autovacuum_setting,) = row
+        # 0 means off. 1 means full. 2 means incremental.
+        return autovacuum_setting != 0
+
     async def run(self) -> None:
         """Ports the SQLite database to a PostgreSQL database.
 
@@ -634,6 +643,21 @@ class Porter:
                 allow_outdated_version=True,
             )
 
+            # For safety, ensure auto_vacuums are disabled.
+            if await self.sqlite_store.db_pool.runInteraction(
+                "is_sqlite_autovacuum_enabled", self._is_sqlite_autovacuum_enabled
+            ):
+                end_error = (
+                    "auto_vacuum is enabled in the SQLite database."
+                    " (This is not the default configuration.)\n"
+                    " This script relies on rowids being consistent and must not"
+                    " be used if the database could be vacuumed between re-runs.\n"
+                    " To disable auto_vacuum, you need to stop Synapse and run the following SQL:\n"
+                    " PRAGMA auto_vacuum=off;\n"
+                    " VACUUM;"
+                )
+                return
+
             # Check if all background updates are done, abort if not.
             updates_complete = (
                 await self.sqlite_store.db_pool.updates.has_completed_background_updates()
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 5a410f805a..9a1aea083f 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -20,22 +20,24 @@ from netaddr import IPAddress
 from twisted.web.server import Request
 
 from synapse import event_auth
-from synapse.api.auth_blocking import AuthBlocking
 from synapse.api.constants import EventTypes, HistoryVisibility, Membership
 from synapse.api.errors import (
     AuthError,
     Codes,
     InvalidClientTokenError,
     MissingClientTokenError,
+    UnstableSpecAuthError,
 )
 from synapse.appservice import ApplicationService
 from synapse.http import get_request_user_agent
 from synapse.http.site import SynapseRequest
-from synapse.logging.opentracing import active_span, force_tracing, start_active_span
-from synapse.storage.databases.main.registration import TokenLookupResult
-from synapse.types import Requester, UserID, create_requester
-from synapse.util.caches.lrucache import LruCache
-from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
+from synapse.logging.opentracing import (
+    active_span,
+    force_tracing,
+    start_active_span,
+    trace,
+)
+from synapse.types import Requester, create_requester
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -47,10 +49,6 @@ logger = logging.getLogger(__name__)
 GUEST_DEVICE_ID = "guest_device"
 
 
-class _InvalidMacaroonException(Exception):
-    pass
-
-
 class Auth:
     """
     This class contains functions for authenticating users of our client-server API.
@@ -62,29 +60,23 @@ class Auth:
         self.store = hs.get_datastores().main
         self._account_validity_handler = hs.get_account_validity_handler()
         self._storage_controllers = hs.get_storage_controllers()
-
-        self.token_cache: LruCache[str, Tuple[str, bool]] = LruCache(
-            10000, "token_cache"
-        )
-
-        self._auth_blocking = AuthBlocking(self.hs)
+        self._macaroon_generator = hs.get_macaroon_generator()
 
         self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips
         self._track_puppeted_user_ips = hs.config.api.track_puppeted_user_ips
-        self._macaroon_secret_key = hs.config.key.macaroon_secret_key
         self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
 
     async def check_user_in_room(
         self,
         room_id: str,
-        user_id: str,
+        requester: Requester,
         allow_departed_users: bool = False,
     ) -> Tuple[str, Optional[str]]:
         """Check if the user is in the room, or was at some point.
         Args:
             room_id: The room to check.
 
-            user_id: The user to check.
+            requester: The user making the request, according to the access token.
 
             current_state: Optional map of the current state of the room.
                 If provided then that map is used to check whether they are a
@@ -101,6 +93,7 @@ class Auth:
             membership event ID of the user.
         """
 
+        user_id = requester.user.to_string()
         (
             membership,
             member_event_id,
@@ -119,14 +112,16 @@ class Auth:
                 forgot = await self.store.did_forget(user_id, room_id)
                 if not forgot:
                     return membership, member_event_id
-
-        raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
+        raise UnstableSpecAuthError(
+            403,
+            "User %s not in room %s" % (user_id, room_id),
+            errcode=Codes.NOT_JOINED,
+        )
 
     async def get_user_by_req(
         self,
         request: SynapseRequest,
         allow_guest: bool = False,
-        rights: str = "access",
         allow_expired: bool = False,
     ) -> Requester:
         """Get a registered user's ID.
@@ -135,7 +130,6 @@ class Auth:
             request: An HTTP request with an access_token query parameter.
             allow_guest: If False, will raise an AuthError if the user making the
                 request is a guest.
-            rights: The operation being performed; the access token must allow this
             allow_expired: If True, allow the request through even if the account
                 is expired, or session token lifetime has ended. Note that
                 /login will deliver access tokens regardless of expiration.
@@ -150,7 +144,7 @@ class Auth:
         parent_span = active_span()
         with start_active_span("get_user_by_req"):
             requester = await self._wrapped_get_user_by_req(
-                request, allow_guest, rights, allow_expired
+                request, allow_guest, allow_expired
             )
 
             if parent_span:
@@ -176,7 +170,6 @@ class Auth:
         self,
         request: SynapseRequest,
         allow_guest: bool,
-        rights: str,
         allow_expired: bool,
     ) -> Requester:
         """Helper for get_user_by_req
@@ -189,96 +182,69 @@ class Auth:
 
             access_token = self.get_access_token_from_request(request)
 
-            (
-                user_id,
-                device_id,
-                app_service,
-            ) = await self._get_appservice_user_id_and_device_id(request)
-            if user_id and app_service:
-                if ip_addr and self._track_appservice_user_ips:
-                    await self.store.insert_client_ip(
-                        user_id=user_id,
-                        access_token=access_token,
-                        ip=ip_addr,
-                        user_agent=user_agent,
-                        device_id="dummy-device"
-                        if device_id is None
-                        else device_id,  # stubbed
-                    )
-
-                requester = create_requester(
-                    user_id, app_service=app_service, device_id=device_id
+            # First check if it could be a request from an appservice
+            requester = await self._get_appservice_user(request)
+            if not requester:
+                # If not, it should be from a regular user
+                requester = await self.get_user_by_access_token(
+                    access_token, allow_expired=allow_expired
                 )
 
-                request.requester = user_id
-                return requester
-
-            user_info = await self.get_user_by_access_token(
-                access_token, rights, allow_expired=allow_expired
-            )
-            token_id = user_info.token_id
-            is_guest = user_info.is_guest
-            shadow_banned = user_info.shadow_banned
-
-            # Deny the request if the user account has expired.
-            if not allow_expired:
-                if await self._account_validity_handler.is_user_expired(
-                    user_info.user_id
-                ):
-                    # Raise the error if either an account validity module has determined
-                    # the account has expired, or the legacy account validity
-                    # implementation is enabled and determined the account has expired
-                    raise AuthError(
-                        403,
-                        "User account has expired",
-                        errcode=Codes.EXPIRED_ACCOUNT,
-                    )
-
-            device_id = user_info.device_id
-
-            if access_token and ip_addr:
+                # Deny the request if the user account has expired.
+                # This check is only done for regular users, not appservice ones.
+                if not allow_expired:
+                    if await self._account_validity_handler.is_user_expired(
+                        requester.user.to_string()
+                    ):
+                        # Raise the error if either an account validity module has determined
+                        # the account has expired, or the legacy account validity
+                        # implementation is enabled and determined the account has expired
+                        raise AuthError(
+                            403,
+                            "User account has expired",
+                            errcode=Codes.EXPIRED_ACCOUNT,
+                        )
+
+            if ip_addr and (
+                not requester.app_service or self._track_appservice_user_ips
+            ):
+                # XXX(quenting): I'm 95% confident that we could skip setting the
+                # device_id to "dummy-device" for appservices, and that the only impact
+                # would be some rows which whould not deduplicate in the 'user_ips'
+                # table during the transition
+                recorded_device_id = (
+                    "dummy-device"
+                    if requester.device_id is None and requester.app_service is not None
+                    else requester.device_id
+                )
                 await self.store.insert_client_ip(
-                    user_id=user_info.token_owner,
+                    user_id=requester.authenticated_entity,
                     access_token=access_token,
                     ip=ip_addr,
                     user_agent=user_agent,
-                    device_id=device_id,
+                    device_id=recorded_device_id,
                 )
+
                 # Track also the puppeted user client IP if enabled and the user is puppeting
                 if (
-                    user_info.user_id != user_info.token_owner
+                    requester.user.to_string() != requester.authenticated_entity
                     and self._track_puppeted_user_ips
                 ):
                     await self.store.insert_client_ip(
-                        user_id=user_info.user_id,
+                        user_id=requester.user.to_string(),
                         access_token=access_token,
                         ip=ip_addr,
                         user_agent=user_agent,
-                        device_id=device_id,
+                        device_id=requester.device_id,
                     )
 
-            if is_guest and not allow_guest:
+            if requester.is_guest and not allow_guest:
                 raise AuthError(
                     403,
                     "Guest access not allowed",
                     errcode=Codes.GUEST_ACCESS_FORBIDDEN,
                 )
 
-            # Mark the token as used. This is used to invalidate old refresh
-            # tokens after some time.
-            if not user_info.token_used and token_id is not None:
-                await self.store.mark_access_token_as_used(token_id)
-
-            requester = create_requester(
-                user_info.user_id,
-                token_id,
-                is_guest,
-                shadow_banned,
-                device_id,
-                app_service=app_service,
-                authenticated_entity=user_info.token_owner,
-            )
-
             request.requester = requester
             return requester
         except KeyError:
@@ -315,9 +281,7 @@ class Auth:
                 403, "Application service has not registered this user (%s)" % user_id
             )
 
-    async def _get_appservice_user_id_and_device_id(
-        self, request: Request
-    ) -> Tuple[Optional[str], Optional[str], Optional[ApplicationService]]:
+    async def _get_appservice_user(self, request: Request) -> Optional[Requester]:
         """
         Given a request, reads the request parameters to determine:
         - whether it's an application service that's making this request
@@ -332,15 +296,13 @@ class Auth:
              Must use `org.matrix.msc3202.device_id` in place of `device_id` for now.
 
         Returns:
-            3-tuple of
-            (user ID?, device ID?, application service?)
+            the application service `Requester` of that request
 
         Postconditions:
-        - If an application service is returned, so is a user ID
-        - A user ID is never returned without an application service
-        - A device ID is never returned without a user ID or an application service
-        - The returned application service, if present, is permitted to control the
-          returned user ID.
+        - The `app_service` field in the returned `Requester` is set
+        - The `user_id` field in the returned `Requester` is either the application
+          service sender or the controlled user set by the `user_id` URI parameter
+        - The returned application service is permitted to control the returned user ID.
         - The returned device ID, if present, has been checked to be a valid device ID
           for the returned user ID.
         """
@@ -350,12 +312,12 @@ class Auth:
             self.get_access_token_from_request(request)
         )
         if app_service is None:
-            return None, None, None
+            return None
 
         if app_service.ip_range_whitelist:
             ip_address = IPAddress(request.getClientAddress().host)
             if ip_address not in app_service.ip_range_whitelist:
-                return None, None, None
+                return None
 
         # This will always be set by the time Twisted calls us.
         assert request.args is not None
@@ -389,20 +351,19 @@ class Auth:
                     Codes.EXCLUSIVE,
                 )
 
-        return effective_user_id, effective_device_id, app_service
+        return create_requester(
+            effective_user_id, app_service=app_service, device_id=effective_device_id
+        )
 
     async def get_user_by_access_token(
         self,
         token: str,
-        rights: str = "access",
         allow_expired: bool = False,
-    ) -> TokenLookupResult:
+    ) -> Requester:
         """Validate access token and get user_id from it
 
         Args:
             token: The access token to get the user by
-            rights: The operation being performed; the access token must
-                allow this
             allow_expired: If False, raises an InvalidClientTokenError
                 if the token is expired
 
@@ -413,70 +374,69 @@ class Auth:
                 is invalid
         """
 
-        if rights == "access":
-            # First look in the database to see if the access token is present
-            # as an opaque token.
-            r = await self.store.get_user_by_access_token(token)
-            if r:
-                valid_until_ms = r.valid_until_ms
-                if (
-                    not allow_expired
-                    and valid_until_ms is not None
-                    and valid_until_ms < self.clock.time_msec()
-                ):
-                    # there was a valid access token, but it has expired.
-                    # soft-logout the user.
-                    raise InvalidClientTokenError(
-                        msg="Access token has expired", soft_logout=True
-                    )
+        # First look in the database to see if the access token is present
+        # as an opaque token.
+        user_info = await self.store.get_user_by_access_token(token)
+        if user_info:
+            valid_until_ms = user_info.valid_until_ms
+            if (
+                not allow_expired
+                and valid_until_ms is not None
+                and valid_until_ms < self.clock.time_msec()
+            ):
+                # there was a valid access token, but it has expired.
+                # soft-logout the user.
+                raise InvalidClientTokenError(
+                    msg="Access token has expired", soft_logout=True
+                )
 
-                return r
+            # Mark the token as used. This is used to invalidate old refresh
+            # tokens after some time.
+            await self.store.mark_access_token_as_used(user_info.token_id)
+
+            requester = create_requester(
+                user_id=user_info.user_id,
+                access_token_id=user_info.token_id,
+                is_guest=user_info.is_guest,
+                shadow_banned=user_info.shadow_banned,
+                device_id=user_info.device_id,
+                authenticated_entity=user_info.token_owner,
+            )
+
+            return requester
 
         # If the token isn't found in the database, then it could still be a
-        # macaroon, so we check that here.
+        # macaroon for a guest, so we check that here.
         try:
-            user_id, guest = self._parse_and_validate_macaroon(token, rights)
-
-            if rights == "access":
-                if not guest:
-                    # non-guest access tokens must be in the database
-                    logger.warning("Unrecognised access token - not in store.")
-                    raise InvalidClientTokenError()
-
-                # Guest access tokens are not stored in the database (there can
-                # only be one access token per guest, anyway).
-                #
-                # In order to prevent guest access tokens being used as regular
-                # user access tokens (and hence getting around the invalidation
-                # process), we look up the user id and check that it is indeed
-                # a guest user.
-                #
-                # It would of course be much easier to store guest access
-                # tokens in the database as well, but that would break existing
-                # guest tokens.
-                stored_user = await self.store.get_user_by_id(user_id)
-                if not stored_user:
-                    raise InvalidClientTokenError("Unknown user_id %s" % user_id)
-                if not stored_user["is_guest"]:
-                    raise InvalidClientTokenError(
-                        "Guest access token used for regular user"
-                    )
-
-                ret = TokenLookupResult(
-                    user_id=user_id,
-                    is_guest=True,
-                    # all guests get the same device id
-                    device_id=GUEST_DEVICE_ID,
+            user_id = self._macaroon_generator.verify_guest_token(token)
+
+            # Guest access tokens are not stored in the database (there can
+            # only be one access token per guest, anyway).
+            #
+            # In order to prevent guest access tokens being used as regular
+            # user access tokens (and hence getting around the invalidation
+            # process), we look up the user id and check that it is indeed
+            # a guest user.
+            #
+            # It would of course be much easier to store guest access
+            # tokens in the database as well, but that would break existing
+            # guest tokens.
+            stored_user = await self.store.get_user_by_id(user_id)
+            if not stored_user:
+                raise InvalidClientTokenError("Unknown user_id %s" % user_id)
+            if not stored_user["is_guest"]:
+                raise InvalidClientTokenError(
+                    "Guest access token used for regular user"
                 )
-            elif rights == "delete_pusher":
-                # We don't store these tokens in the database
 
-                ret = TokenLookupResult(user_id=user_id, is_guest=False)
-            else:
-                raise RuntimeError("Unknown rights setting %s", rights)
-            return ret
+            return create_requester(
+                user_id=user_id,
+                is_guest=True,
+                # all guests get the same device id
+                device_id=GUEST_DEVICE_ID,
+                authenticated_entity=user_id,
+            )
         except (
-            _InvalidMacaroonException,
             pymacaroons.exceptions.MacaroonException,
             TypeError,
             ValueError,
@@ -488,78 +448,6 @@ class Auth:
             )
             raise InvalidClientTokenError("Invalid access token passed.")
 
-    def _parse_and_validate_macaroon(
-        self, token: str, rights: str = "access"
-    ) -> Tuple[str, bool]:
-        """Takes a macaroon and tries to parse and validate it. This is cached
-        if and only if rights == access and there isn't an expiry.
-
-        On invalid macaroon raises _InvalidMacaroonException
-
-        Returns:
-            (user_id, is_guest)
-        """
-        if rights == "access":
-            cached = self.token_cache.get(token, None)
-            if cached:
-                return cached
-
-        try:
-            macaroon = pymacaroons.Macaroon.deserialize(token)
-        except Exception:  # deserialize can throw more-or-less anything
-            # The access token doesn't look like a macaroon.
-            raise _InvalidMacaroonException()
-
-        try:
-            user_id = get_value_from_macaroon(macaroon, "user_id")
-
-            guest = False
-            for caveat in macaroon.caveats:
-                if caveat.caveat_id == "guest = true":
-                    guest = True
-
-            self.validate_macaroon(macaroon, rights, user_id=user_id)
-        except (
-            pymacaroons.exceptions.MacaroonException,
-            KeyError,
-            TypeError,
-            ValueError,
-        ):
-            raise InvalidClientTokenError("Invalid macaroon passed.")
-
-        if rights == "access":
-            self.token_cache[token] = (user_id, guest)
-
-        return user_id, guest
-
-    def validate_macaroon(
-        self, macaroon: pymacaroons.Macaroon, type_string: str, user_id: str
-    ) -> None:
-        """
-        validate that a Macaroon is understood by and was signed by this server.
-
-        Args:
-            macaroon: The macaroon to validate
-            type_string: The kind of token required (e.g. "access", "delete_pusher")
-            user_id: The user_id required
-        """
-        v = pymacaroons.Verifier()
-
-        # the verifier runs a test for every caveat on the macaroon, to check
-        # that it is met for the current request. Each caveat must match at
-        # least one of the predicates specified by satisfy_exact or
-        # specify_general.
-        v.satisfy_exact("gen = 1")
-        v.satisfy_exact("type = " + type_string)
-        v.satisfy_exact("user_id = %s" % user_id)
-        v.satisfy_exact("guest = true")
-        satisfy_expiry(v, self.clock.time_msec)
-
-        # access_tokens include a nonce for uniqueness: any value is acceptable
-        v.satisfy_general(lambda c: c.startswith("nonce = "))
-
-        v.verify(macaroon, self._macaroon_secret_key)
-
     def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService:
         token = self.get_access_token_from_request(request)
         service = self.store.get_app_service_by_token(token)
@@ -569,32 +457,33 @@ class Auth:
         request.requester = create_requester(service.sender, app_service=service)
         return service
 
-    async def is_server_admin(self, user: UserID) -> bool:
+    async def is_server_admin(self, requester: Requester) -> bool:
         """Check if the given user is a local server admin.
 
         Args:
-            user: user to check
+            requester: The user making the request, according to the access token.
 
         Returns:
             True if the user is an admin
         """
-        return await self.store.is_server_admin(user)
+        return await self.store.is_server_admin(requester.user)
 
-    async def check_can_change_room_list(self, room_id: str, user: UserID) -> bool:
+    async def check_can_change_room_list(
+        self, room_id: str, requester: Requester
+    ) -> bool:
         """Determine whether the user is allowed to edit the room's entry in the
         published room list.
 
         Args:
-            room_id
-            user
+            room_id: The room to check.
+            requester: The user making the request, according to the access token.
         """
 
-        is_admin = await self.is_server_admin(user)
+        is_admin = await self.is_server_admin(requester)
         if is_admin:
             return True
 
-        user_id = user.to_string()
-        await self.check_user_in_room(room_id, user_id)
+        await self.check_user_in_room(room_id, requester)
 
         # We currently require the user is a "moderator" in the room. We do this
         # by checking if they would (theoretically) be able to change the
@@ -613,7 +502,9 @@ class Auth:
         send_level = event_auth.get_send_level(
             EventTypes.CanonicalAlias, "", power_level_event
         )
-        user_level = event_auth.get_user_power_level(user_id, auth_events)
+        user_level = event_auth.get_user_power_level(
+            requester.user.to_string(), auth_events
+        )
 
         return user_level >= send_level
 
@@ -669,17 +560,18 @@ class Auth:
 
             return query_params[0].decode("ascii")
 
+    @trace
     async def check_user_in_room_or_world_readable(
-        self, room_id: str, user_id: str, allow_departed_users: bool = False
+        self, room_id: str, requester: Requester, allow_departed_users: bool = False
     ) -> Tuple[str, Optional[str]]:
         """Checks that the user is or was in the room or the room is world
         readable. If it isn't then an exception is raised.
 
         Args:
-            room_id: room to check
-            user_id: user to check
-            allow_departed_users: if True, accept users that were previously
-                members but have now departed
+            room_id: The room to check.
+            requester: The user making the request, according to the access token.
+            allow_departed_users: If True, accept users that were previously
+                members but have now departed.
 
         Returns:
             Resolves to the current membership of the user in the room and the
@@ -694,7 +586,7 @@ class Auth:
             #  * The user is a guest user, and has joined the room
             # else it will throw.
             return await self.check_user_in_room(
-                room_id, user_id, allow_departed_users=allow_departed_users
+                room_id, requester, allow_departed_users=allow_departed_users
             )
         except AuthError:
             visibility = await self._storage_controllers.state.get_current_state_event(
@@ -706,19 +598,9 @@ class Auth:
                 == HistoryVisibility.WORLD_READABLE
             ):
                 return Membership.JOIN, None
-            raise AuthError(
+            raise UnstableSpecAuthError(
                 403,
                 "User %s not in room %s, and room previews are disabled"
-                % (user_id, room_id),
+                % (requester.user, room_id),
+                errcode=Codes.NOT_JOINED,
             )
-
-    async def check_auth_blocking(
-        self,
-        user_id: Optional[str] = None,
-        threepid: Optional[dict] = None,
-        user_type: Optional[str] = None,
-        requester: Optional[Requester] = None,
-    ) -> None:
-        await self._auth_blocking.check_auth_blocking(
-            user_id=user_id, threepid=threepid, user_type=user_type, requester=requester
-        )
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index e1d31cabed..c178ddf070 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -216,11 +216,11 @@ class EventContentFields:
     MSC2716_HISTORICAL: Final = "org.matrix.msc2716.historical"
     # For "insertion" events to indicate what the next batch ID should be in
     # order to connect to it
-    MSC2716_NEXT_BATCH_ID: Final = "org.matrix.msc2716.next_batch_id"
+    MSC2716_NEXT_BATCH_ID: Final = "next_batch_id"
     # Used on "batch" events to indicate which insertion event it connects to
-    MSC2716_BATCH_ID: Final = "org.matrix.msc2716.batch_id"
+    MSC2716_BATCH_ID: Final = "batch_id"
     # For "marker" events
-    MSC2716_MARKER_INSERTION: Final = "org.matrix.msc2716.marker.insertion"
+    MSC2716_INSERTION_EVENT_REFERENCE: Final = "insertion_event_reference"
 
     # The authorising user for joining a restricted room.
     AUTHORISING_USER: Final = "join_authorised_via_users_server"
@@ -257,5 +257,15 @@ class GuestAccess:
 
 class ReceiptTypes:
     READ: Final = "m.read"
-    READ_PRIVATE: Final = "org.matrix.msc2285.read.private"
+    READ_PRIVATE: Final = "m.read.private"
     FULLY_READ: Final = "m.fully_read"
+
+
+class PublicRoomsFilterFields:
+    """Fields in the search filter for `/publicRooms` that we understand.
+
+    As defined in https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3publicrooms
+    """
+
+    GENERIC_SEARCH_TERM: Final = "generic_search_term"
+    ROOM_TYPES: Final = "room_types"
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index cc7b785472..e6dea89c6d 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -26,6 +26,7 @@ from twisted.web import http
 from synapse.util import json_decoder
 
 if typing.TYPE_CHECKING:
+    from synapse.config.homeserver import HomeServerConfig
     from synapse.types import JsonDict
 
 logger = logging.getLogger(__name__)
@@ -80,6 +81,12 @@ class Codes(str, Enum):
     INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
     USER_DEACTIVATED = "M_USER_DEACTIVATED"
 
+    # Part of MSC3848
+    # https://github.com/matrix-org/matrix-spec-proposals/pull/3848
+    ALREADY_JOINED = "ORG.MATRIX.MSC3848.ALREADY_JOINED"
+    NOT_JOINED = "ORG.MATRIX.MSC3848.NOT_JOINED"
+    INSUFFICIENT_POWER = "ORG.MATRIX.MSC3848.INSUFFICIENT_POWER"
+
     # The account has been suspended on the server.
     # By opposition to `USER_DEACTIVATED`, this is a reversible measure
     # that can possibly be appealed and reverted.
@@ -167,7 +174,7 @@ class SynapseError(CodeMessageException):
         else:
             self._additional_fields = dict(additional_fields)
 
-    def error_dict(self) -> "JsonDict":
+    def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
         return cs_error(self.msg, self.errcode, **self._additional_fields)
 
 
@@ -213,7 +220,7 @@ class ConsentNotGivenError(SynapseError):
         )
         self._consent_uri = consent_uri
 
-    def error_dict(self) -> "JsonDict":
+    def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
         return cs_error(self.msg, self.errcode, consent_uri=self._consent_uri)
 
 
@@ -297,8 +304,45 @@ class AuthError(SynapseError):
     other poorly-defined times.
     """
 
-    def __init__(self, code: int, msg: str, errcode: str = Codes.FORBIDDEN):
-        super().__init__(code, msg, errcode)
+    def __init__(
+        self,
+        code: int,
+        msg: str,
+        errcode: str = Codes.FORBIDDEN,
+        additional_fields: Optional[dict] = None,
+    ):
+        super().__init__(code, msg, errcode, additional_fields)
+
+
+class UnstableSpecAuthError(AuthError):
+    """An error raised when a new error code is being proposed to replace a previous one.
+    This error will return a "org.matrix.unstable.errcode" property with the new error code,
+    with the previous error code still being defined in the "errcode" property.
+
+    This error will include `org.matrix.msc3848.unstable.errcode` in the C-S error body.
+    """
+
+    def __init__(
+        self,
+        code: int,
+        msg: str,
+        errcode: str,
+        previous_errcode: str = Codes.FORBIDDEN,
+        additional_fields: Optional[dict] = None,
+    ):
+        self.previous_errcode = previous_errcode
+        super().__init__(code, msg, errcode, additional_fields)
+
+    def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
+        fields = {}
+        if config is not None and config.experimental.msc3848_enabled:
+            fields["org.matrix.msc3848.unstable.errcode"] = self.errcode
+        return cs_error(
+            self.msg,
+            self.previous_errcode,
+            **fields,
+            **self._additional_fields,
+        )
 
 
 class InvalidClientCredentialsError(SynapseError):
@@ -332,8 +376,8 @@ class InvalidClientTokenError(InvalidClientCredentialsError):
         super().__init__(msg=msg, errcode="M_UNKNOWN_TOKEN")
         self._soft_logout = soft_logout
 
-    def error_dict(self) -> "JsonDict":
-        d = super().error_dict()
+    def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
+        d = super().error_dict(config)
         d["soft_logout"] = self._soft_logout
         return d
 
@@ -356,7 +400,7 @@ class ResourceLimitError(SynapseError):
         self.limit_type = limit_type
         super().__init__(code, msg, errcode=errcode)
 
-    def error_dict(self) -> "JsonDict":
+    def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
         return cs_error(
             self.msg,
             self.errcode,
@@ -391,7 +435,7 @@ class InvalidCaptchaError(SynapseError):
         super().__init__(code, msg, errcode)
         self.error_url = error_url
 
-    def error_dict(self) -> "JsonDict":
+    def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
         return cs_error(self.msg, self.errcode, error_url=self.error_url)
 
 
@@ -408,7 +452,7 @@ class LimitExceededError(SynapseError):
         super().__init__(code, msg, errcode)
         self.retry_after_ms = retry_after_ms
 
-    def error_dict(self) -> "JsonDict":
+    def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
         return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms)
 
 
@@ -423,7 +467,7 @@ class RoomKeysVersionError(SynapseError):
         super().__init__(403, "Wrong room_keys version", Codes.WRONG_ROOM_KEYS_VERSION)
         self.current_version = current_version
 
-    def error_dict(self) -> "JsonDict":
+    def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
         return cs_error(self.msg, self.errcode, current_version=self.current_version)
 
 
@@ -463,7 +507,7 @@ class IncompatibleRoomVersionError(SynapseError):
 
         self._room_version = room_version
 
-    def error_dict(self) -> "JsonDict":
+    def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
         return cs_error(self.msg, self.errcode, room_version=self._room_version)
 
 
@@ -509,7 +553,7 @@ class UnredactedContentDeletedError(SynapseError):
         )
         self.content_keep_ms = content_keep_ms
 
-    def error_dict(self) -> "JsonDict":
+    def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
         extra = {}
         if self.content_keep_ms is not None:
             extra = {"fi.mau.msc2815.content_keep_ms": self.content_keep_ms}
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index 849c18ceda..044c7d4926 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -17,7 +17,7 @@ from collections import OrderedDict
 from typing import Hashable, Optional, Tuple
 
 from synapse.api.errors import LimitExceededError
-from synapse.config.ratelimiting import RateLimitConfig
+from synapse.config.ratelimiting import RatelimitSettings
 from synapse.storage.databases.main import DataStore
 from synapse.types import Requester
 from synapse.util import Clock
@@ -27,6 +27,33 @@ class Ratelimiter:
     """
     Ratelimit actions marked by arbitrary keys.
 
+    (Note that the source code speaks of "actions" and "burst_count" rather than
+    "tokens" and a "bucket_size".)
+
+    This is a "leaky bucket as a meter". For each key to be tracked there is a bucket
+    containing some number 0 <= T <= `burst_count` of tokens corresponding to previously
+    permitted requests for that key. Each bucket starts empty, and gradually leaks
+    tokens at a rate of `rate_hz`.
+
+    Upon an incoming request, we must determine:
+    - the key that this request falls under (which bucket to inspect), and
+    - the cost C of this request in tokens.
+    Then, if there is room in the bucket for C tokens (T + C <= `burst_count`),
+    the request is permitted and `cost` tokens are added to the bucket.
+    Otherwise the request is denied, and the bucket continues to hold T tokens.
+
+    This means that the limiter enforces an average request frequency of `rate_hz`,
+    while accumulating a buffer of up to `burst_count` requests which can be consumed
+    instantaneously.
+
+    The tricky bit is the leaking. We do not want to have a periodic process which
+    leaks every bucket! Instead, we track
+    - the time point when the bucket was last completely empty, and
+    - how many tokens have added to the bucket permitted since then.
+    Then for each incoming request, we can calculate how many tokens have leaked
+    since this time point, and use that to decide if we should accept or reject the
+    request.
+
     Args:
         clock: A homeserver clock, for retrieving the current time
         rate_hz: The long term number of actions that can be performed in a second.
@@ -41,14 +68,30 @@ class Ratelimiter:
         self.burst_count = burst_count
         self.store = store
 
-        # A ordered dictionary keeping track of actions, when they were last
-        # performed and how often. Each entry is a mapping from a key of arbitrary type
-        # to a tuple representing:
-        #   * How many times an action has occurred since a point in time
-        #   * The point in time
-        #   * The rate_hz of this particular entry. This can vary per request
+        # An ordered dictionary representing the token buckets tracked by this rate
+        # limiter. Each entry maps a key of arbitrary type to a tuple representing:
+        #   * The number of tokens currently in the bucket,
+        #   * The time point when the bucket was last completely empty, and
+        #   * The rate_hz (leak rate) of this particular bucket.
         self.actions: OrderedDict[Hashable, Tuple[float, float, float]] = OrderedDict()
 
+    def _get_key(
+        self, requester: Optional[Requester], key: Optional[Hashable]
+    ) -> Hashable:
+        """Use the requester's MXID as a fallback key if no key is provided."""
+        if key is None:
+            if not requester:
+                raise ValueError("Must supply at least one of `requester` or `key`")
+
+            key = requester.user.to_string()
+        return key
+
+    def _get_action_counts(
+        self, key: Hashable, time_now_s: float
+    ) -> Tuple[float, float, float]:
+        """Retrieve the action counts, with a fallback representing an empty bucket."""
+        return self.actions.get(key, (0.0, time_now_s, 0.0))
+
     async def can_do_action(
         self,
         requester: Optional[Requester],
@@ -88,11 +131,7 @@ class Ratelimiter:
                 * The reactor timestamp for when the action can be performed next.
                   -1 if rate_hz is less than or equal to zero
         """
-        if key is None:
-            if not requester:
-                raise ValueError("Must supply at least one of `requester` or `key`")
-
-            key = requester.user.to_string()
+        key = self._get_key(requester, key)
 
         if requester:
             # Disable rate limiting of users belonging to any AS that is configured
@@ -121,13 +160,16 @@ class Ratelimiter:
         self._prune_message_counts(time_now_s)
 
         # Check if there is an existing count entry for this key
-        action_count, time_start, _ = self.actions.get(key, (0.0, time_now_s, 0.0))
+        action_count, time_start, _ = self._get_action_counts(key, time_now_s)
 
         # Check whether performing another action is allowed
         time_delta = time_now_s - time_start
         performed_count = action_count - time_delta * rate_hz
         if performed_count < 0:
             performed_count = 0
+
+            # Reset the start time and forgive all actions
+            action_count = 0
             time_start = time_now_s
 
         # This check would be easier read as performed_count + n_actions > burst_count,
@@ -140,7 +182,7 @@ class Ratelimiter:
         else:
             # We haven't reached our limit yet
             allowed = True
-            action_count = performed_count + n_actions
+            action_count = action_count + n_actions
 
         if update:
             self.actions[key] = (action_count, time_start, rate_hz)
@@ -161,6 +203,37 @@ class Ratelimiter:
 
         return allowed, time_allowed
 
+    def record_action(
+        self,
+        requester: Optional[Requester],
+        key: Optional[Hashable] = None,
+        n_actions: int = 1,
+        _time_now_s: Optional[float] = None,
+    ) -> None:
+        """Record that an action(s) took place, even if they violate the rate limit.
+
+        This is useful for tracking the frequency of events that happen across
+        federation which we still want to impose local rate limits on. For instance, if
+        we are alice.com monitoring a particular room, we cannot prevent bob.com
+        from joining users to that room. However, we can track the number of recent
+        joins in the room and refuse to serve new joins ourselves if there have been too
+        many in the room across both homeservers.
+
+        Args:
+            requester: The requester that is doing the action, if any.
+            key: An arbitrary key used to classify an action. Defaults to the
+                requester's user ID.
+            n_actions: The number of times the user wants to do this action. If the user
+                cannot do all of the actions, the user's action count is not incremented
+                at all.
+            _time_now_s: The current time. Optional, defaults to the current time according
+                to self.clock. Only used by tests.
+        """
+        key = self._get_key(requester, key)
+        time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
+        action_count, time_start, rate_hz = self._get_action_counts(key, time_now_s)
+        self.actions[key] = (action_count + n_actions, time_start, rate_hz)
+
     def _prune_message_counts(self, time_now_s: float) -> None:
         """Remove message count entries that have not exceeded their defined
         rate_hz limit
@@ -241,8 +314,8 @@ class RequestRatelimiter:
         self,
         store: DataStore,
         clock: Clock,
-        rc_message: RateLimitConfig,
-        rc_admin_redaction: Optional[RateLimitConfig],
+        rc_message: RatelimitSettings,
+        rc_admin_redaction: Optional[RatelimitSettings],
     ):
         self.store = store
         self.clock = clock
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index 3f85d61b46..a0e4ab6db6 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -84,6 +84,8 @@ class RoomVersion:
     # MSC3787: Adds support for a `knock_restricted` join rule, mixing concepts of
     # knocks and restricted join rules into the same join condition.
     msc3787_knock_restricted_join_rule: bool
+    # MSC3667: Enforce integer power levels
+    msc3667_int_only_power_levels: bool
 
 
 class RoomVersions:
@@ -103,6 +105,7 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
         msc3787_knock_restricted_join_rule=False,
+        msc3667_int_only_power_levels=False,
     )
     V2 = RoomVersion(
         "2",
@@ -120,6 +123,7 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
         msc3787_knock_restricted_join_rule=False,
+        msc3667_int_only_power_levels=False,
     )
     V3 = RoomVersion(
         "3",
@@ -137,6 +141,7 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
         msc3787_knock_restricted_join_rule=False,
+        msc3667_int_only_power_levels=False,
     )
     V4 = RoomVersion(
         "4",
@@ -154,6 +159,7 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
         msc3787_knock_restricted_join_rule=False,
+        msc3667_int_only_power_levels=False,
     )
     V5 = RoomVersion(
         "5",
@@ -171,6 +177,7 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
         msc3787_knock_restricted_join_rule=False,
+        msc3667_int_only_power_levels=False,
     )
     V6 = RoomVersion(
         "6",
@@ -188,6 +195,7 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
         msc3787_knock_restricted_join_rule=False,
+        msc3667_int_only_power_levels=False,
     )
     MSC2176 = RoomVersion(
         "org.matrix.msc2176",
@@ -205,6 +213,7 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
         msc3787_knock_restricted_join_rule=False,
+        msc3667_int_only_power_levels=False,
     )
     V7 = RoomVersion(
         "7",
@@ -222,6 +231,7 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
         msc3787_knock_restricted_join_rule=False,
+        msc3667_int_only_power_levels=False,
     )
     V8 = RoomVersion(
         "8",
@@ -239,6 +249,7 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
         msc3787_knock_restricted_join_rule=False,
+        msc3667_int_only_power_levels=False,
     )
     V9 = RoomVersion(
         "9",
@@ -256,9 +267,10 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
         msc3787_knock_restricted_join_rule=False,
+        msc3667_int_only_power_levels=False,
     )
-    MSC2716v3 = RoomVersion(
-        "org.matrix.msc2716v3",
+    MSC3787 = RoomVersion(
+        "org.matrix.msc3787",
         RoomDisposition.UNSTABLE,
         EventFormatVersions.V3,
         StateResolutionVersions.V2,
@@ -267,16 +279,17 @@ class RoomVersions:
         strict_canonicaljson=True,
         limit_notifications_power_levels=True,
         msc2176_redaction_rules=False,
-        msc3083_join_rules=False,
-        msc3375_redaction_rules=False,
+        msc3083_join_rules=True,
+        msc3375_redaction_rules=True,
         msc2403_knocking=True,
-        msc2716_historical=True,
-        msc2716_redactions=True,
-        msc3787_knock_restricted_join_rule=False,
+        msc2716_historical=False,
+        msc2716_redactions=False,
+        msc3787_knock_restricted_join_rule=True,
+        msc3667_int_only_power_levels=False,
     )
-    MSC3787 = RoomVersion(
-        "org.matrix.msc3787",
-        RoomDisposition.UNSTABLE,
+    V10 = RoomVersion(
+        "10",
+        RoomDisposition.STABLE,
         EventFormatVersions.V3,
         StateResolutionVersions.V2,
         enforce_key_validity=True,
@@ -290,6 +303,25 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
         msc3787_knock_restricted_join_rule=True,
+        msc3667_int_only_power_levels=True,
+    )
+    MSC2716v4 = RoomVersion(
+        "org.matrix.msc2716v4",
+        RoomDisposition.UNSTABLE,
+        EventFormatVersions.V3,
+        StateResolutionVersions.V2,
+        enforce_key_validity=True,
+        special_case_aliases_auth=False,
+        strict_canonicaljson=True,
+        limit_notifications_power_levels=True,
+        msc2176_redaction_rules=False,
+        msc3083_join_rules=False,
+        msc3375_redaction_rules=False,
+        msc2403_knocking=True,
+        msc2716_historical=True,
+        msc2716_redactions=True,
+        msc3787_knock_restricted_join_rule=False,
+        msc3667_int_only_power_levels=False,
     )
 
 
@@ -306,8 +338,9 @@ KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = {
         RoomVersions.V7,
         RoomVersions.V8,
         RoomVersions.V9,
-        RoomVersions.MSC2716v3,
         RoomVersions.MSC3787,
+        RoomVersions.V10,
+        RoomVersions.MSC2716v4,
     )
 }
 
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 84e389a6cd..4742435d3b 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -106,7 +106,9 @@ def register_sighup(func: Callable[P, None], *args: P.args, **kwargs: P.kwargs)
 def start_worker_reactor(
     appname: str,
     config: HomeServerConfig,
-    run_command: Callable[[], None] = reactor.run,
+    # Use a lambda to avoid binding to a given reactor at import time.
+    # (needed when synapse.app.complement_fork_starter is being used)
+    run_command: Callable[[], None] = lambda: reactor.run(),
 ) -> None:
     """Run the reactor in the main process
 
@@ -141,7 +143,9 @@ def start_reactor(
     daemonize: bool,
     print_pidfile: bool,
     logger: logging.Logger,
-    run_command: Callable[[], None] = reactor.run,
+    # Use a lambda to avoid binding to a given reactor at import time.
+    # (needed when synapse.app.complement_fork_starter is being used)
+    run_command: Callable[[], None] = lambda: reactor.run(),
 ) -> None:
     """Run the reactor in the main process
 
@@ -262,15 +266,48 @@ def register_start(
     reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper()))
 
 
-def listen_metrics(bind_addresses: Iterable[str], port: int) -> None:
+def listen_metrics(
+    bind_addresses: Iterable[str], port: int, enable_legacy_metric_names: bool
+) -> None:
     """
     Start Prometheus metrics server.
     """
-    from synapse.metrics import RegistryProxy, start_http_server
+    from prometheus_client import start_http_server as start_http_server_prometheus
+
+    from synapse.metrics import (
+        RegistryProxy,
+        start_http_server as start_http_server_legacy,
+    )
 
     for host in bind_addresses:
         logger.info("Starting metrics listener on %s:%d", host, port)
-        start_http_server(port, addr=host, registry=RegistryProxy)
+        if enable_legacy_metric_names:
+            start_http_server_legacy(port, addr=host, registry=RegistryProxy)
+        else:
+            _set_prometheus_client_use_created_metrics(False)
+            start_http_server_prometheus(port, addr=host, registry=RegistryProxy)
+
+
+def _set_prometheus_client_use_created_metrics(new_value: bool) -> None:
+    """
+    Sets whether prometheus_client should expose `_created`-suffixed metrics for
+    all gauges, histograms and summaries.
+    There is no programmatic way to disable this without poking at internals;
+    the proper way is to use an environment variable which prometheus_client
+    loads at import time.
+
+    The motivation for disabling these `_created` metrics is that they're
+    a waste of space as they're not useful but they take up space in Prometheus.
+    """
+
+    import prometheus_client.metrics
+
+    if hasattr(prometheus_client.metrics, "_use_created"):
+        prometheus_client.metrics._use_created = new_value
+    else:
+        logger.error(
+            "Can't disable `_created` metrics in prometheus_client (brittle hack broken?)"
+        )
 
 
 def listen_manhole(
@@ -450,7 +487,7 @@ async def start(hs: "HomeServer") -> None:
     # before we start the listeners.
     module_api = hs.get_module_api()
     for module, config in hs.config.modules.loaded_modules:
-        m = module(config=config, api=module_api)
+        m = module(config, module_api)
         logger.info("Loaded module %s", m)
 
     load_legacy_spam_checkers(hs)
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index 561621a285..8a583d3ec6 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -28,18 +28,22 @@ from synapse.config.homeserver import HomeServerConfig
 from synapse.config.logger import setup_logging
 from synapse.events import EventBase
 from synapse.handlers.admin import ExfiltrationWriter
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
-from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
-from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
 from synapse.replication.slave.storage.devices import SlavedDeviceStore
 from synapse.replication.slave.storage.events import SlavedEventStore
 from synapse.replication.slave.storage.filtering import SlavedFilteringStore
 from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
-from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
-from synapse.replication.slave.storage.registration import SlavedRegistrationStore
 from synapse.server import HomeServer
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.databases.main.account_data import AccountDataWorkerStore
+from synapse.storage.databases.main.appservice import (
+    ApplicationServiceTransactionWorkerStore,
+    ApplicationServiceWorkerStore,
+)
+from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
+from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
+from synapse.storage.databases.main.registration import RegistrationWorkerStore
 from synapse.storage.databases.main.room import RoomWorkerStore
+from synapse.storage.databases.main.tags import TagsWorkerStore
 from synapse.types import StateMap
 from synapse.util import SYNAPSE_VERSION
 from synapse.util.logcontext import LoggingContext
@@ -48,19 +52,30 @@ logger = logging.getLogger("synapse.app.admin_cmd")
 
 
 class AdminCmdSlavedStore(
-    SlavedReceiptsStore,
-    SlavedAccountDataStore,
-    SlavedApplicationServiceStore,
-    SlavedRegistrationStore,
     SlavedFilteringStore,
-    SlavedDeviceInboxStore,
     SlavedDeviceStore,
     SlavedPushRuleStore,
     SlavedEventStore,
-    BaseSlavedStore,
+    TagsWorkerStore,
+    DeviceInboxWorkerStore,
+    AccountDataWorkerStore,
+    ApplicationServiceTransactionWorkerStore,
+    ApplicationServiceWorkerStore,
+    RegistrationWorkerStore,
+    ReceiptsWorkerStore,
     RoomWorkerStore,
 ):
-    pass
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
+        super().__init__(database, db_conn, hs)
+
+        # Annoyingly `filter_events_for_client` assumes that this exists. We
+        # should refactor it to take a `Clock` directly.
+        self.clock = hs.get_clock()
 
 
 class AdminCmdServer(HomeServer):
diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py
index de1bcee0a7..b6aed651ed 100644
--- a/synapse/app/appservice.py
+++ b/synapse/app/appservice.py
@@ -17,6 +17,11 @@ import sys
 from synapse.app.generic_worker import start
 from synapse.util.logcontext import LoggingContext
 
-if __name__ == "__main__":
+
+def main() -> None:
     with LoggingContext("main"):
         start(sys.argv[1:])
+
+
+if __name__ == "__main__":
+    main()
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index de1bcee0a7..b6aed651ed 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -17,6 +17,11 @@ import sys
 from synapse.app.generic_worker import start
 from synapse.util.logcontext import LoggingContext
 
-if __name__ == "__main__":
+
+def main() -> None:
     with LoggingContext("main"):
         start(sys.argv[1:])
+
+
+if __name__ == "__main__":
+    main()
diff --git a/synapse/app/complement_fork_starter.py b/synapse/app/complement_fork_starter.py
new file mode 100644
index 0000000000..89eb07df27
--- /dev/null
+++ b/synapse/app/complement_fork_starter.py
@@ -0,0 +1,190 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ## What this script does
+#
+# This script spawns multiple workers, whilst only going through the code loading
+# process once. The net effect is that start-up time for a swarm of workers is
+# reduced, particularly in CPU-constrained environments.
+#
+# Before the workers are spawned, the database is prepared in order to avoid the
+# workers racing.
+#
+# ## Stability
+#
+# This script is only intended for use within the Synapse images for the
+# Complement test suite.
+# There are currently no stability guarantees whatsoever; especially not about:
+# - whether it will continue to exist in future versions;
+# - the format of its command-line arguments; or
+# - any details about its behaviour or principles of operation.
+#
+# ## Usage
+#
+# The first argument should be the path to the database configuration, used to
+# set up the database. The rest of the arguments are used as follows:
+# Each worker is specified as an argument group (each argument group is
+# separated by '--').
+# The first argument in each argument group is the Python module name of the application
+# to start. Further arguments are then passed to that module as-is.
+#
+# ## Example
+#
+#   python -m synapse.app.complement_fork_starter path_to_db_config.yaml \
+#     synapse.app.homeserver [args..] -- \
+#     synapse.app.generic_worker [args..] -- \
+#   ...
+#     synapse.app.generic_worker [args..]
+#
+import argparse
+import importlib
+import itertools
+import multiprocessing
+import sys
+from typing import Any, Callable, List
+
+from twisted.internet.main import installReactor
+
+
+class ProxiedReactor:
+    """
+    Twisted tracks the 'installed' reactor as a global variable.
+    (Actually, it does some module trickery, but the effect is similar.)
+
+    The default EpollReactor is buggy if it's created before a process is
+    forked, then used in the child.
+    See https://twistedmatrix.com/trac/ticket/4759#comment:17.
+
+    However, importing certain Twisted modules will automatically create and
+    install a reactor if one hasn't already been installed.
+    It's not normally possible to re-install a reactor.
+
+    Given the goal of launching workers with fork() to only import the code once,
+    this presents a conflict.
+    Our work around is to 'install' this ProxiedReactor which prevents Twisted
+    from creating and installing one, but which lets us replace the actual reactor
+    in use later on.
+    """
+
+    def __init__(self) -> None:
+        self.___reactor_target: Any = None
+
+    def _install_real_reactor(self, new_reactor: Any) -> None:
+        """
+        Install a real reactor for this ProxiedReactor to forward lookups onto.
+
+        This method is specific to our ProxiedReactor and should not clash with
+        any names used on an actual Twisted reactor.
+        """
+        self.___reactor_target = new_reactor
+
+    def __getattr__(self, attr_name: str) -> Any:
+        return getattr(self.___reactor_target, attr_name)
+
+
+def _worker_entrypoint(
+    func: Callable[[], None], proxy_reactor: ProxiedReactor, args: List[str]
+) -> None:
+    """
+    Entrypoint for a forked worker process.
+
+    We just need to set up the command-line arguments, create our real reactor
+    and then kick off the worker's main() function.
+    """
+
+    sys.argv = args
+
+    from twisted.internet.epollreactor import EPollReactor
+
+    proxy_reactor._install_real_reactor(EPollReactor())
+    func()
+
+
+def main() -> None:
+    """
+    Entrypoint for the forking launcher.
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument("db_config", help="Path to database config file")
+    parser.add_argument(
+        "args",
+        nargs="...",
+        help="Argument groups separated by `--`. "
+        "The first argument of each group is a Synapse app name. "
+        "Subsequent arguments are passed through.",
+    )
+    ns = parser.parse_args()
+
+    # Split up the subsequent arguments into each workers' arguments;
+    # `--` is our delimiter of choice.
+    args_by_worker: List[List[str]] = [
+        list(args)
+        for cond, args in itertools.groupby(ns.args, lambda ele: ele != "--")
+        if cond and args
+    ]
+
+    # Prevent Twisted from installing a shared reactor that all the workers will
+    # inherit when we fork(), by installing our own beforehand.
+    proxy_reactor = ProxiedReactor()
+    installReactor(proxy_reactor)
+
+    # Import the entrypoints for all the workers.
+    worker_functions = []
+    for worker_args in args_by_worker:
+        worker_module = importlib.import_module(worker_args[0])
+        worker_functions.append(worker_module.main)
+
+    # We need to prepare the database first as otherwise all the workers will
+    # try to create a schema version table and some will crash out.
+    from synapse._scripts import update_synapse_database
+
+    update_proc = multiprocessing.Process(
+        target=_worker_entrypoint,
+        args=(
+            update_synapse_database.main,
+            proxy_reactor,
+            [
+                "update_synapse_database",
+                "--database-config",
+                ns.db_config,
+                "--run-background-updates",
+            ],
+        ),
+    )
+    print("===== PREPARING DATABASE =====", file=sys.stderr)
+    update_proc.start()
+    update_proc.join()
+    print("===== PREPARED DATABASE =====", file=sys.stderr)
+
+    # At this point, we've imported all the main entrypoints for all the workers.
+    # Now we basically just fork() out to create the workers we need.
+    # Because we're using fork(), all the workers get a clone of this launcher's
+    # memory space and don't need to repeat the work of loading the code!
+    # Instead of using fork() directly, we use the multiprocessing library,
+    # which uses fork() on Unix platforms.
+    processes = []
+    for (func, worker_args) in zip(worker_functions, args_by_worker):
+        process = multiprocessing.Process(
+            target=_worker_entrypoint, args=(func, proxy_reactor, worker_args)
+        )
+        process.start()
+        processes.append(process)
+
+    # Be a good parent and wait for our children to die before exiting.
+    for process in processes:
+        process.join()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py
index 885454ed44..7995d99825 100644
--- a/synapse/app/event_creator.py
+++ b/synapse/app/event_creator.py
@@ -17,6 +17,11 @@ import sys
 from synapse.app.generic_worker import start
 from synapse.util.logcontext import LoggingContext
 
-if __name__ == "__main__":
+
+def main() -> None:
     with LoggingContext("main"):
         start(sys.argv[1:])
+
+
+if __name__ == "__main__":
+    main()
diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py
index de1bcee0a7..b6aed651ed 100644
--- a/synapse/app/federation_reader.py
+++ b/synapse/app/federation_reader.py
@@ -17,6 +17,11 @@ import sys
 from synapse.app.generic_worker import start
 from synapse.util.logcontext import LoggingContext
 
-if __name__ == "__main__":
+
+def main() -> None:
     with LoggingContext("main"):
         start(sys.argv[1:])
+
+
+if __name__ == "__main__":
+    main()
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index de1bcee0a7..b6aed651ed 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -17,6 +17,11 @@ import sys
 from synapse.app.generic_worker import start
 from synapse.util.logcontext import LoggingContext
 
-if __name__ == "__main__":
+
+def main() -> None:
     with LoggingContext("main"):
         start(sys.argv[1:])
+
+
+if __name__ == "__main__":
+    main()
diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py
index de1bcee0a7..b6aed651ed 100644
--- a/synapse/app/frontend_proxy.py
+++ b/synapse/app/frontend_proxy.py
@@ -17,6 +17,11 @@ import sys
 from synapse.app.generic_worker import start
 from synapse.util.logcontext import LoggingContext
 
-if __name__ == "__main__":
+
+def main() -> None:
     with LoggingContext("main"):
         start(sys.argv[1:])
+
+
+if __name__ == "__main__":
+    main()
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 4a987fb759..5e3825fca6 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -48,20 +48,12 @@ from synapse.http.site import SynapseRequest, SynapseSite
 from synapse.logging.context import LoggingContext
 from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
 from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
-from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
-from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
 from synapse.replication.slave.storage.devices import SlavedDeviceStore
-from synapse.replication.slave.storage.directory import DirectoryStore
 from synapse.replication.slave.storage.events import SlavedEventStore
 from synapse.replication.slave.storage.filtering import SlavedFilteringStore
 from synapse.replication.slave.storage.keys import SlavedKeyStore
-from synapse.replication.slave.storage.profile import SlavedProfileStore
 from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
 from synapse.replication.slave.storage.pushers import SlavedPusherStore
-from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
-from synapse.replication.slave.storage.registration import SlavedRegistrationStore
 from synapse.rest.admin import register_servlets_for_media_repo
 from synapse.rest.client import (
     account_data,
@@ -100,8 +92,15 @@ from synapse.rest.key.v2 import KeyApiV2Resource
 from synapse.rest.synapse.client import build_synapse_client_resource_tree
 from synapse.rest.well_known import well_known_resource
 from synapse.server import HomeServer
+from synapse.storage.databases.main.account_data import AccountDataWorkerStore
+from synapse.storage.databases.main.appservice import (
+    ApplicationServiceTransactionWorkerStore,
+    ApplicationServiceWorkerStore,
+)
 from synapse.storage.databases.main.censor_events import CensorEventsStore
 from synapse.storage.databases.main.client_ips import ClientIpWorkerStore
+from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
+from synapse.storage.databases.main.directory import DirectoryWorkerStore
 from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyStore
 from synapse.storage.databases.main.lock import LockStore
 from synapse.storage.databases.main.media_repository import MediaRepositoryStore
@@ -110,11 +109,15 @@ from synapse.storage.databases.main.monthly_active_users import (
     MonthlyActiveUsersWorkerStore,
 )
 from synapse.storage.databases.main.presence import PresenceStore
+from synapse.storage.databases.main.profile import ProfileWorkerStore
+from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
+from synapse.storage.databases.main.registration import RegistrationWorkerStore
 from synapse.storage.databases.main.room import RoomWorkerStore
 from synapse.storage.databases.main.room_batch import RoomBatchStore
 from synapse.storage.databases.main.search import SearchStore
 from synapse.storage.databases.main.session import SessionStore
 from synapse.storage.databases.main.stats import StatsStore
+from synapse.storage.databases.main.tags import TagsWorkerStore
 from synapse.storage.databases.main.transactions import TransactionWorkerStore
 from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
 from synapse.storage.databases.main.user_directory import UserDirectoryStore
@@ -227,11 +230,11 @@ class GenericWorkerSlavedStore(
     UIAuthWorkerStore,
     EndToEndRoomKeyStore,
     PresenceStore,
-    SlavedDeviceInboxStore,
+    DeviceInboxWorkerStore,
     SlavedDeviceStore,
-    SlavedReceiptsStore,
     SlavedPushRuleStore,
-    SlavedAccountDataStore,
+    TagsWorkerStore,
+    AccountDataWorkerStore,
     SlavedPusherStore,
     CensorEventsStore,
     ClientIpWorkerStore,
@@ -239,19 +242,20 @@ class GenericWorkerSlavedStore(
     SlavedKeyStore,
     RoomWorkerStore,
     RoomBatchStore,
-    DirectoryStore,
-    SlavedApplicationServiceStore,
-    SlavedRegistrationStore,
-    SlavedProfileStore,
+    DirectoryWorkerStore,
+    ApplicationServiceTransactionWorkerStore,
+    ApplicationServiceWorkerStore,
+    ProfileWorkerStore,
     SlavedFilteringStore,
     MonthlyActiveUsersWorkerStore,
     MediaRepositoryStore,
     ServerMetricsStore,
+    ReceiptsWorkerStore,
+    RegistrationWorkerStore,
     SearchStore,
     TransactionWorkerStore,
     LockStore,
     SessionStore,
-    BaseSlavedStore,
 ):
     # Properties that multiple storage classes define. Tell mypy what the
     # expected type is.
@@ -408,7 +412,11 @@ class GenericWorkerServer(HomeServer):
                         "enable_metrics is not True!"
                     )
                 else:
-                    _base.listen_metrics(listener.bind_addresses, listener.port)
+                    _base.listen_metrics(
+                        listener.bind_addresses,
+                        listener.port,
+                        enable_legacy_metric_names=self.config.metrics.enable_legacy_metrics,
+                    )
             else:
                 logger.warning("Unsupported listener type: %s", listener.type)
 
@@ -437,6 +445,13 @@ def start(config_options: List[str]) -> None:
         "synapse.app.user_dir",
     )
 
+    if config.experimental.faster_joins_enabled:
+        raise ConfigError(
+            "You have enabled the experimental `faster_joins` config option, but it is "
+            "not compatible with worker deployments yet. Please disable `faster_joins` "
+            "or run Synapse as a single process deployment instead."
+        )
+
     synapse.events.USE_FROZEN_DICTS = config.server.use_frozen_dicts
     synapse.util.caches.TRACK_MEMORY_USAGE = config.caches.track_memory_usage
 
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 745e704141..e57a926032 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -44,7 +44,6 @@ from synapse.app._base import (
     register_start,
 )
 from synapse.config._base import ConfigError, format_config_error
-from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.server import ListenerConfig
 from synapse.federation.transport.server import TransportLayerServer
@@ -202,7 +201,7 @@ class SynapseHomeServer(HomeServer):
                 }
             )
 
-            if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+            if self.config.email.can_verify_email:
                 from synapse.rest.synapse.client.password_reset import (
                     PasswordResetSubmitTokenResource,
                 )
@@ -220,7 +219,10 @@ class SynapseHomeServer(HomeServer):
             resources.update({"/_matrix/consent": consent_resource})
 
         if name == "federation":
-            resources.update({FEDERATION_PREFIX: TransportLayerServer(self)})
+            federation_resource: Resource = TransportLayerServer(self)
+            if compress:
+                federation_resource = gz_wrap(federation_resource)
+            resources.update({FEDERATION_PREFIX: federation_resource})
 
         if name == "openid":
             resources.update(
@@ -305,7 +307,11 @@ class SynapseHomeServer(HomeServer):
                         "enable_metrics is not True!"
                     )
                 else:
-                    _base.listen_metrics(listener.bind_addresses, listener.port)
+                    _base.listen_metrics(
+                        listener.bind_addresses,
+                        listener.port,
+                        enable_legacy_metric_names=self.config.metrics.enable_legacy_metrics,
+                    )
             else:
                 # this shouldn't happen, as the listener type should have been checked
                 # during parsing
diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index de1bcee0a7..b6aed651ed 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -17,6 +17,11 @@ import sys
 from synapse.app.generic_worker import start
 from synapse.util.logcontext import LoggingContext
 
-if __name__ == "__main__":
+
+def main() -> None:
     with LoggingContext("main"):
         start(sys.argv[1:])
+
+
+if __name__ == "__main__":
+    main()
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index de1bcee0a7..b6aed651ed 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -17,6 +17,11 @@ import sys
 from synapse.app.generic_worker import start
 from synapse.util.logcontext import LoggingContext
 
-if __name__ == "__main__":
+
+def main() -> None:
     with LoggingContext("main"):
         start(sys.argv[1:])
+
+
+if __name__ == "__main__":
+    main()
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index de1bcee0a7..b6aed651ed 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -17,6 +17,11 @@ import sys
 from synapse.app.generic_worker import start
 from synapse.util.logcontext import LoggingContext
 
-if __name__ == "__main__":
+
+def main() -> None:
     with LoggingContext("main"):
         start(sys.argv[1:])
+
+
+if __name__ == "__main__":
+    main()
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index 14bde27179..34f23c4e5d 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -17,6 +17,11 @@ import sys
 from synapse.app.generic_worker import start
 from synapse.util.logcontext import LoggingContext
 
-if __name__ == "__main__":
+
+def main() -> None:
     with LoggingContext("main"):
         start(sys.argv[1:])
+
+
+if __name__ == "__main__":
+    main()
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index df1c214462..0963fb3bb4 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -53,6 +53,18 @@ sent_events_counter = Counter(
     "synapse_appservice_api_sent_events", "Number of events sent to the AS", ["service"]
 )
 
+sent_ephemeral_counter = Counter(
+    "synapse_appservice_api_sent_ephemeral",
+    "Number of ephemeral events sent to the AS",
+    ["service"],
+)
+
+sent_todevice_counter = Counter(
+    "synapse_appservice_api_sent_todevice",
+    "Number of todevice messages sent to the AS",
+    ["service"],
+)
+
 HOUR_IN_MS = 60 * 60 * 1000
 
 
@@ -310,6 +322,8 @@ class ApplicationServiceApi(SimpleHttpClient):
                 )
             sent_transactions_counter.labels(service.id).inc()
             sent_events_counter.labels(service.id).inc(len(serialized_events))
+            sent_ephemeral_counter.labels(service.id).inc(len(ephemeral))
+            sent_todevice_counter.labels(service.id).inc(len(to_device_messages))
             return True
         except CodeMessageException as e:
             logger.warning(
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index de5e5216c2..430ffbcd1f 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -319,7 +319,9 @@ class _ServiceQueuer:
         rooms_of_interesting_users.update(event.room_id for event in events)
         # EDUs
         rooms_of_interesting_users.update(
-            ephemeral["room_id"] for ephemeral in ephemerals
+            ephemeral["room_id"]
+            for ephemeral in ephemerals
+            if ephemeral.get("room_id") is not None
         )
 
         # Look up the AS users in those rooms
@@ -329,8 +331,9 @@ class _ServiceQueuer:
             )
 
         # Add recipients of to-device messages.
-        # device_message["user_id"] is the ID of the recipient.
-        users.update(device_message["user_id"] for device_message in to_device_messages)
+        users.update(
+            device_message["to_user_id"] for device_message in to_device_messages
+        )
 
         # Compute and return the counts / fallback key usage states
         otk_counts = await self._store.count_bulk_e2e_one_time_keys_for_as(users)
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 42364fc133..1f6362aedd 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -18,7 +18,9 @@ import argparse
 import errno
 import logging
 import os
+import re
 from collections import OrderedDict
+from enum import Enum, auto
 from hashlib import sha256
 from textwrap import dedent
 from typing import (
@@ -96,16 +98,16 @@ def format_config_error(e: ConfigError) -> Iterator[str]:
 # We split these messages out to allow packages to override with package
 # specific instructions.
 MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS = """\
-Please opt in or out of reporting anonymized homeserver usage statistics, by
-setting the `report_stats` key in your config file to either True or False.
+Please opt in or out of reporting homeserver usage statistics, by setting
+the `report_stats` key in your config file to either True or False.
 """
 
 MISSING_REPORT_STATS_SPIEL = """\
 We would really appreciate it if you could help our project out by reporting
-anonymized usage statistics from your homeserver. Only very basic aggregate
-data (e.g. number of users) will be reported, but it helps us to track the
-growth of the Matrix community, and helps us to make Matrix a success, as well
-as to convince other networks that they should peer with us.
+homeserver usage statistics from your homeserver. Your homeserver's server name,
+along with very basic aggregate data (e.g. number of users) will be reported. But
+it helps us to track the growth of the Matrix community, and helps us to make Matrix
+a success, as well as to convince other networks that they should peer with us.
 
 Thank you.
 """
@@ -123,7 +125,10 @@ CONFIG_FILE_HEADER = """\
 # should have the same indentation.
 #
 # [1] https://docs.ansible.com/ansible/latest/reference_appendices/YAMLSyntax.html
-
+#
+# For more information on how to configure Synapse, including a complete accounting of
+# each option, go to docs/usage/configuration/config_documentation.md or
+# https://matrix-org.github.io/synapse/latest/usage/configuration/config_documentation.html
 """
 
 
@@ -470,7 +475,7 @@ class RootConfig:
             The yaml config file
         """
 
-        return CONFIG_FILE_HEADER + "\n\n".join(
+        conf = CONFIG_FILE_HEADER + "\n".join(
             dedent(conf)
             for conf in self.invoke_all(
                 "generate_config_section",
@@ -485,6 +490,8 @@ class RootConfig:
                 tls_private_key_path=tls_private_key_path,
             ).values()
         )
+        conf = re.sub("\n{2,}", "\n", conf)
+        return conf
 
     @classmethod
     def load_config(
@@ -597,25 +604,51 @@ class RootConfig:
             " may specify directories containing *.yaml files.",
         )
 
-        generate_group = parser.add_argument_group("Config generation")
-        generate_group.add_argument(
+        # we nest the mutually-exclusive group inside another group so that the help
+        # text shows them in their own group.
+        generate_mode_group = parser.add_argument_group(
+            "Config generation mode",
+        )
+        generate_mode_exclusive = generate_mode_group.add_mutually_exclusive_group()
+        generate_mode_exclusive.add_argument(
+            # hidden option to make the type and default work
+            "--generate-mode",
+            help=argparse.SUPPRESS,
+            type=_ConfigGenerateMode,
+            default=_ConfigGenerateMode.GENERATE_MISSING_AND_RUN,
+        )
+        generate_mode_exclusive.add_argument(
             "--generate-config",
-            action="store_true",
             help="Generate a config file, then exit.",
+            action="store_const",
+            const=_ConfigGenerateMode.GENERATE_EVERYTHING_AND_EXIT,
+            dest="generate_mode",
         )
-        generate_group.add_argument(
+        generate_mode_exclusive.add_argument(
             "--generate-missing-configs",
             "--generate-keys",
-            action="store_true",
             help="Generate any missing additional config files, then exit.",
+            action="store_const",
+            const=_ConfigGenerateMode.GENERATE_MISSING_AND_EXIT,
+            dest="generate_mode",
+        )
+        generate_mode_exclusive.add_argument(
+            "--generate-missing-and-run",
+            help="Generate any missing additional config files, then run. This is the "
+            "default behaviour.",
+            action="store_const",
+            const=_ConfigGenerateMode.GENERATE_MISSING_AND_RUN,
+            dest="generate_mode",
         )
+
+        generate_group = parser.add_argument_group("Details for --generate-config")
         generate_group.add_argument(
             "-H", "--server-name", help="The server name to generate a config file for."
         )
         generate_group.add_argument(
             "--report-stats",
             action="store",
-            help="Whether the generated config reports anonymized usage statistics.",
+            help="Whether the generated config reports homeserver usage statistics.",
             choices=["yes", "no"],
         )
         generate_group.add_argument(
@@ -664,11 +697,12 @@ class RootConfig:
         config_dir_path = os.path.abspath(config_dir_path)
         data_dir_path = os.getcwd()
 
-        generate_missing_configs = config_args.generate_missing_configs
-
         obj = cls(config_files)
 
-        if config_args.generate_config:
+        if (
+            config_args.generate_mode
+            == _ConfigGenerateMode.GENERATE_EVERYTHING_AND_EXIT
+        ):
             if config_args.report_stats is None:
                 parser.error(
                     "Please specify either --report-stats=yes or --report-stats=no\n\n"
@@ -726,11 +760,14 @@ class RootConfig:
                     )
                     % (config_path,)
                 )
-                generate_missing_configs = True
 
         config_dict = read_config_files(config_files)
-        if generate_missing_configs:
-            obj.generate_missing_files(config_dict, config_dir_path)
+        obj.generate_missing_files(config_dict, config_dir_path)
+
+        if config_args.generate_mode in (
+            _ConfigGenerateMode.GENERATE_EVERYTHING_AND_EXIT,
+            _ConfigGenerateMode.GENERATE_MISSING_AND_EXIT,
+        ):
             return None
 
         obj.parse_config_dict(
@@ -959,6 +996,12 @@ def read_file(file_path: Any, config_path: Iterable[str]) -> str:
         raise ConfigError("Error accessing file %r" % (file_path,), config_path) from e
 
 
+class _ConfigGenerateMode(Enum):
+    GENERATE_MISSING_AND_RUN = auto()
+    GENERATE_MISSING_AND_EXIT = auto()
+    GENERATE_EVERYTHING_AND_EXIT = auto()
+
+
 __all__ = [
     "Config",
     "RootConfig",
diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py
index d1335e77cd..b3972ede96 100644
--- a/synapse/config/account_validity.py
+++ b/synapse/config/account_validity.py
@@ -23,7 +23,7 @@ LEGACY_TEMPLATE_DIR_WARNING = """
 This server's configuration file is using the deprecated 'template_dir' setting in the
 'account_validity' section. Support for this setting has been deprecated and will be
 removed in a future version of Synapse. Server admins should instead use the new
-'custom_templates_directory' setting documented here:
+'custom_template_directory' setting documented here:
 https://matrix-org.github.io/synapse/latest/templates.html
 ---------------------------------------------------------------------------------------"""
 
diff --git a/synapse/config/api.py b/synapse/config/api.py
index 2cc6305340..e46728e73f 100644
--- a/synapse/config/api.py
+++ b/synapse/config/api.py
@@ -31,54 +31,6 @@ class ApiConfig(Config):
         self.room_prejoin_state = list(self._get_prejoin_state_types(config))
         self.track_puppeted_user_ips = config.get("track_puppeted_user_ips", False)
 
-    def generate_config_section(cls, **kwargs: Any) -> str:
-        formatted_default_state_types = "\n".join(
-            "           # - %s" % (t,) for t in _DEFAULT_PREJOIN_STATE_TYPES
-        )
-
-        return """\
-        ## API Configuration ##
-
-        # Controls for the state that is shared with users who receive an invite
-        # to a room
-        #
-        room_prejoin_state:
-           # By default, the following state event types are shared with users who
-           # receive invites to the room:
-           #
-%(formatted_default_state_types)s
-           #
-           # Uncomment the following to disable these defaults (so that only the event
-           # types listed in 'additional_event_types' are shared). Defaults to 'false'.
-           #
-           #disable_default_event_types: true
-
-           # Additional state event types to share with users when they are invited
-           # to a room.
-           #
-           # By default, this list is empty (so only the default event types are shared).
-           #
-           #additional_event_types:
-           #  - org.example.custom.event.type
-
-        # We record the IP address of clients used to access the API for various
-        # reasons, including displaying it to the user in the "Where you're signed in"
-        # dialog.
-        #
-        # By default, when puppeting another user via the admin API, the client IP
-        # address is recorded against the user who created the access token (ie, the
-        # admin user), and *not* the puppeted user.
-        #
-        # Uncomment the following to also record the IP address against the puppeted
-        # user. (This also means that the puppeted user will count as an "active" user
-        # for the purpose of monthly active user tracking - see 'limit_usage_by_mau' etc
-        # above.)
-        #
-        #track_puppeted_user_ips: true
-        """ % {
-            "formatted_default_state_types": formatted_default_state_types
-        }
-
     def _get_prejoin_state_types(self, config: JsonDict) -> Iterable[str]:
         """Get the event types to include in the prejoin state
 
diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index 16f93273b3..00182090b2 100644
--- a/synapse/config/appservice.py
+++ b/synapse/config/appservice.py
@@ -35,20 +35,6 @@ class AppServiceConfig(Config):
         self.app_service_config_files = config.get("app_service_config_files", [])
         self.track_appservice_user_ips = config.get("track_appservice_user_ips", False)
 
-    def generate_config_section(cls, **kwargs: Any) -> str:
-        return """\
-        # A list of application service config files to use
-        #
-        #app_service_config_files:
-        #  - app_service_1.yaml
-        #  - app_service_2.yaml
-
-        # Uncomment to enable tracking of application service IP addresses. Implicitly
-        # enables MAU tracking for application service users.
-        #
-        #track_appservice_user_ips: true
-        """
-
 
 def load_appservices(
     hostname: str, config_files: List[str]
diff --git a/synapse/config/auth.py b/synapse/config/auth.py
index 265a554a5d..35774962c0 100644
--- a/synapse/config/auth.py
+++ b/synapse/config/auth.py
@@ -53,78 +53,3 @@ class AuthConfig(Config):
         self.ui_auth_session_timeout = self.parse_duration(
             ui_auth.get("session_timeout", 0)
         )
-
-    def generate_config_section(self, **kwargs: Any) -> str:
-        return """\
-        password_config:
-           # Uncomment to disable password login.
-           # Set to `only_for_reauth` to permit reauthentication for users that
-           # have passwords and are already logged in.
-           #
-           #enabled: false
-
-           # Uncomment to disable authentication against the local password
-           # database. This is ignored if `enabled` is false, and is only useful
-           # if you have other password_providers.
-           #
-           #localdb_enabled: false
-
-           # Uncomment and change to a secret random string for extra security.
-           # DO NOT CHANGE THIS AFTER INITIAL SETUP!
-           #
-           #pepper: "EVEN_MORE_SECRET"
-
-           # Define and enforce a password policy. Each parameter is optional.
-           # This is an implementation of MSC2000.
-           #
-           policy:
-              # Whether to enforce the password policy.
-              # Defaults to 'false'.
-              #
-              #enabled: true
-
-              # Minimum accepted length for a password.
-              # Defaults to 0.
-              #
-              #minimum_length: 15
-
-              # Whether a password must contain at least one digit.
-              # Defaults to 'false'.
-              #
-              #require_digit: true
-
-              # Whether a password must contain at least one symbol.
-              # A symbol is any character that's not a number or a letter.
-              # Defaults to 'false'.
-              #
-              #require_symbol: true
-
-              # Whether a password must contain at least one lowercase letter.
-              # Defaults to 'false'.
-              #
-              #require_lowercase: true
-
-              # Whether a password must contain at least one uppercase letter.
-              # Defaults to 'false'.
-              #
-              #require_uppercase: true
-
-        ui_auth:
-            # The amount of time to allow a user-interactive authentication session
-            # to be active.
-            #
-            # This defaults to 0, meaning the user is queried for their credentials
-            # before every action, but this can be overridden to allow a single
-            # validation to be re-used.  This weakens the protections afforded by
-            # the user-interactive authentication process, by allowing for multiple
-            # (and potentially different) operations to use the same validation session.
-            #
-            # This is ignored for potentially "dangerous" operations (including
-            # deactivating an account, modifying an account password, and
-            # adding a 3PID).
-            #
-            # Uncomment below to allow for credential validation to last for 15
-            # seconds.
-            #
-            #session_timeout: "15s"
-        """
diff --git a/synapse/config/background_updates.py b/synapse/config/background_updates.py
index 07fadbe041..1c6cd97de8 100644
--- a/synapse/config/background_updates.py
+++ b/synapse/config/background_updates.py
@@ -21,40 +21,6 @@ from ._base import Config
 class BackgroundUpdateConfig(Config):
     section = "background_updates"
 
-    def generate_config_section(self, **kwargs: Any) -> str:
-        return """\
-        ## Background Updates ##
-
-        # Background updates are database updates that are run in the background in batches.
-        # The duration, minimum batch size, default batch size, whether to sleep between batches and if so, how long to
-        # sleep can all be configured. This is helpful to speed up or slow down the updates.
-        #
-        background_updates:
-            # How long in milliseconds to run a batch of background updates for. Defaults to 100. Uncomment and set
-            # a time to change the default.
-            #
-            #background_update_duration_ms: 500
-
-            # Whether to sleep between updates. Defaults to True. Uncomment to change the default.
-            #
-            #sleep_enabled: false
-
-            # If sleeping between updates, how long in milliseconds to sleep for. Defaults to 1000. Uncomment
-            # and set a duration to change the default.
-            #
-            #sleep_duration_ms: 300
-
-            # Minimum size a batch of background updates can be. Must be greater than 0. Defaults to 1. Uncomment and
-            # set a size to change the default.
-            #
-            #min_batch_size: 10
-
-            # The batch size to use for the first iteration of a new background update. The default is 100.
-            # Uncomment and set a size to change the default.
-            #
-            #default_batch_size: 50
-        """
-
     def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         bg_update_config = config.get("background_updates") or {}
 
diff --git a/synapse/config/cache.py b/synapse/config/cache.py
index d2f55534d7..2db8cfb005 100644
--- a/synapse/config/cache.py
+++ b/synapse/config/cache.py
@@ -21,7 +21,7 @@ from typing import Any, Callable, Dict, Optional
 import attr
 
 from synapse.types import JsonDict
-from synapse.util.check_dependencies import DependencyException, check_requirements
+from synapse.util.check_dependencies import check_requirements
 
 from ._base import Config, ConfigError
 
@@ -113,97 +113,6 @@ class CacheConfig(Config):
         with _CACHES_LOCK:
             _CACHES.clear()
 
-    def generate_config_section(self, **kwargs: Any) -> str:
-        return """\
-        ## Caching ##
-
-        # Caching can be configured through the following options.
-        #
-        # A cache 'factor' is a multiplier that can be applied to each of
-        # Synapse's caches in order to increase or decrease the maximum
-        # number of entries that can be stored.
-        #
-        # The configuration for cache factors (caches.global_factor and
-        # caches.per_cache_factors) can be reloaded while the application is running,
-        # by sending a SIGHUP signal to the Synapse process. Changes to other parts of
-        # the caching config will NOT be applied after a SIGHUP is received; a restart
-        # is necessary.
-
-        # The number of events to cache in memory. Not affected by
-        # caches.global_factor.
-        #
-        #event_cache_size: 10K
-
-        caches:
-          # Controls the global cache factor, which is the default cache factor
-          # for all caches if a specific factor for that cache is not otherwise
-          # set.
-          #
-          # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment
-          # variable. Setting by environment variable takes priority over
-          # setting through the config file.
-          #
-          # Defaults to 0.5, which will half the size of all caches.
-          #
-          #global_factor: 1.0
-
-          # A dictionary of cache name to cache factor for that individual
-          # cache. Overrides the global cache factor for a given cache.
-          #
-          # These can also be set through environment variables comprised
-          # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital
-          # letters and underscores. Setting by environment variable
-          # takes priority over setting through the config file.
-          # Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0
-          #
-          # Some caches have '*' and other characters that are not
-          # alphanumeric or underscores. These caches can be named with or
-          # without the special characters stripped. For example, to specify
-          # the cache factor for `*stateGroupCache*` via an environment
-          # variable would be `SYNAPSE_CACHE_FACTOR_STATEGROUPCACHE=2.0`.
-          #
-          per_cache_factors:
-            #get_users_who_share_room_with_user: 2.0
-
-          # Controls whether cache entries are evicted after a specified time
-          # period. Defaults to true. Uncomment to disable this feature.
-          #
-          #expire_caches: false
-
-          # If expire_caches is enabled, this flag controls how long an entry can
-          # be in a cache without having been accessed before being evicted.
-          # Defaults to 30m. Uncomment to set a different time to live for cache entries.
-          #
-          #cache_entry_ttl: 30m
-
-          # This flag enables cache autotuning, and is further specified by the sub-options `max_cache_memory_usage`,
-          # `target_cache_memory_usage`, `min_cache_ttl`. These flags work in conjunction with each other to maintain
-          # a balance between cache memory usage and cache entry availability. You must be using jemalloc to utilize
-          # this option, and all three of the options must be specified for this feature to work.
-          #cache_autotuning:
-            # This flag sets a ceiling on much memory the cache can use before caches begin to be continuously evicted.
-            # They will continue to be evicted until the memory usage drops below the `target_memory_usage`, set in
-            # the flag below, or until the `min_cache_ttl` is hit.
-            #max_cache_memory_usage: 1024M
-
-            # This flag sets a rough target for the desired memory usage of the caches.
-            #target_cache_memory_usage: 758M
-
-            # 'min_cache_ttl` sets a limit under which newer cache entries are not evicted and is only applied when
-            # caches are actively being evicted/`max_cache_memory_usage` has been exceeded. This is to protect hot caches
-            # from being emptied while Synapse is evicting due to memory.
-            #min_cache_ttl: 5m
-
-          # Controls how long the results of a /sync request are cached for after
-          # a successful response is returned. A higher duration can help clients with
-          # intermittent connections, at the cost of higher memory usage.
-          #
-          # By default, this is zero, which means that sync responses are not cached
-          # at all.
-          #
-          #sync_response_cache_duration: 2m
-        """
-
     def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         """Populate this config object with values from `config`.
 
@@ -250,12 +159,7 @@ class CacheConfig(Config):
 
         self.track_memory_usage = cache_config.get("track_memory_usage", False)
         if self.track_memory_usage:
-            try:
-                check_requirements("cache_memory")
-            except DependencyException as e:
-                raise ConfigError(
-                    e.message  # noqa: B306, DependencyException.message is a property
-                )
+            check_requirements("cache_memory")
 
         expire_caches = cache_config.get("expire_caches", True)
         cache_entry_ttl = cache_config.get("cache_entry_ttl", "30m")
@@ -297,7 +201,7 @@ class CacheConfig(Config):
             self.cache_autotuning["min_cache_ttl"] = self.parse_duration(min_cache_ttl)
 
         self.sync_response_cache_duration = self.parse_duration(
-            cache_config.get("sync_response_cache_duration", 0)
+            cache_config.get("sync_response_cache_duration", "2m")
         )
 
     def resize_all_caches(self) -> None:
diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py
index 92c603f224..1737d5e327 100644
--- a/synapse/config/captcha.py
+++ b/synapse/config/captcha.py
@@ -45,30 +45,3 @@ class CaptchaConfig(Config):
             "https://www.recaptcha.net/recaptcha/api/siteverify",
         )
         self.recaptcha_template = self.read_template("recaptcha.html")
-
-    def generate_config_section(self, **kwargs: Any) -> str:
-        return """\
-        ## Captcha ##
-        # See docs/CAPTCHA_SETUP.md for full details of configuring this.
-
-        # This homeserver's ReCAPTCHA public key. Must be specified if
-        # enable_registration_captcha is enabled.
-        #
-        #recaptcha_public_key: "YOUR_PUBLIC_KEY"
-
-        # This homeserver's ReCAPTCHA private key. Must be specified if
-        # enable_registration_captcha is enabled.
-        #
-        #recaptcha_private_key: "YOUR_PRIVATE_KEY"
-
-        # Uncomment to enable ReCaptcha checks when registering, preventing signup
-        # unless a captcha is answered. Requires a valid ReCaptcha
-        # public/private key. Defaults to 'false'.
-        #
-        #enable_registration_captcha: true
-
-        # The API endpoint to use for verifying m.login.recaptcha responses.
-        # Defaults to "https://www.recaptcha.net/recaptcha/api/siteverify".
-        #
-        #recaptcha_siteverify_api: "https://my.recaptcha.site"
-        """
diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index 8af0794ba4..9152c06bd6 100644
--- a/synapse/config/cas.py
+++ b/synapse/config/cas.py
@@ -53,37 +53,6 @@ class CasConfig(Config):
             self.cas_displayname_attribute = None
             self.cas_required_attributes = []
 
-    def generate_config_section(self, **kwargs: Any) -> str:
-        return """\
-        # Enable Central Authentication Service (CAS) for registration and login.
-        #
-        cas_config:
-          # Uncomment the following to enable authorization against a CAS server.
-          # Defaults to false.
-          #
-          #enabled: true
-
-          # The URL of the CAS authorization endpoint.
-          #
-          #server_url: "https://cas-server.com"
-
-          # The attribute of the CAS response to use as the display name.
-          #
-          # If unset, no displayname will be set.
-          #
-          #displayname_attribute: name
-
-          # It is possible to configure Synapse to only allow logins if CAS attributes
-          # match particular values. All of the keys in the mapping below must exist
-          # and the values must match the given value. Alternately if the given value
-          # is None then any value is allowed (the attribute just must exist).
-          # All of the listed attributes must match for the login to be permitted.
-          #
-          #required_attributes:
-          #  userGroup: "staff"
-          #  department: None
-        """
-
 
 # CAS uses a legacy required attributes mapping, not the one provided by
 # SsoAttributeRequirement.
diff --git a/synapse/config/consent.py b/synapse/config/consent.py
index 8ee3d34521..be74609dc4 100644
--- a/synapse/config/consent.py
+++ b/synapse/config/consent.py
@@ -20,58 +20,6 @@ from synapse.types import JsonDict
 
 from ._base import Config
 
-DEFAULT_CONFIG = """\
-# User Consent configuration
-#
-# for detailed instructions, see
-# https://matrix-org.github.io/synapse/latest/consent_tracking.html
-#
-# Parts of this section are required if enabling the 'consent' resource under
-# 'listeners', in particular 'template_dir' and 'version'.
-#
-# 'template_dir' gives the location of the templates for the HTML forms.
-# This directory should contain one subdirectory per language (eg, 'en', 'fr'),
-# and each language directory should contain the policy document (named as
-# '<version>.html') and a success page (success.html).
-#
-# 'version' specifies the 'current' version of the policy document. It defines
-# the version to be served by the consent resource if there is no 'v'
-# parameter.
-#
-# 'server_notice_content', if enabled, will send a user a "Server Notice"
-# asking them to consent to the privacy policy. The 'server_notices' section
-# must also be configured for this to work. Notices will *not* be sent to
-# guest users unless 'send_server_notice_to_guests' is set to true.
-#
-# 'block_events_error', if set, will block any attempts to send events
-# until the user consents to the privacy policy. The value of the setting is
-# used as the text of the error.
-#
-# 'require_at_registration', if enabled, will add a step to the registration
-# process, similar to how captcha works. Users will be required to accept the
-# policy before their account is created.
-#
-# 'policy_name' is the display name of the policy users will see when registering
-# for an account. Has no effect unless `require_at_registration` is enabled.
-# Defaults to "Privacy Policy".
-#
-#user_consent:
-#  template_dir: res/templates/privacy
-#  version: 1.0
-#  server_notice_content:
-#    msgtype: m.text
-#    body: >-
-#      To continue using this homeserver you must review and agree to the
-#      terms and conditions at %(consent_uri)s
-#  send_server_notice_to_guests: true
-#  block_events_error: >-
-#    To continue using this homeserver you must review and agree to the
-#    terms and conditions at %(consent_uri)s
-#  require_at_registration: false
-#  policy_name: Privacy Policy
-#
-"""
-
 
 class ConsentConfig(Config):
 
@@ -118,6 +66,3 @@ class ConsentConfig(Config):
         self.user_consent_policy_name = consent_config.get(
             "policy_name", "Privacy Policy"
         )
-
-    def generate_config_section(self, **kwargs: Any) -> str:
-        return DEFAULT_CONFIG
diff --git a/synapse/config/database.py b/synapse/config/database.py
index de0d3ca0f0..928fec8dfe 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -28,56 +28,6 @@ Ignoring 'database_path' setting: not using a sqlite3 database.
 """
 
 DEFAULT_CONFIG = """\
-## Database ##
-
-# The 'database' setting defines the database that synapse uses to store all of
-# its data.
-#
-# 'name' gives the database engine to use: either 'sqlite3' (for SQLite) or
-# 'psycopg2' (for PostgreSQL).
-#
-# 'txn_limit' gives the maximum number of transactions to run per connection
-# before reconnecting. Defaults to 0, which means no limit.
-#
-# 'allow_unsafe_locale' is an option specific to Postgres. Under the default behavior, Synapse will refuse to
-# start if the postgres db is set to a non-C locale. You can override this behavior (which is *not* recommended)
-# by setting 'allow_unsafe_locale' to true. Note that doing so may corrupt your database. You can find more information
-# here: https://matrix-org.github.io/synapse/latest/postgres.html#fixing-incorrect-collate-or-ctype and here:
-# https://wiki.postgresql.org/wiki/Locale_data_changes
-#
-# 'args' gives options which are passed through to the database engine,
-# except for options starting 'cp_', which are used to configure the Twisted
-# connection pool. For a reference to valid arguments, see:
-#   * for sqlite: https://docs.python.org/3/library/sqlite3.html#sqlite3.connect
-#   * for postgres: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS
-#   * for the connection pool: https://twistedmatrix.com/documents/current/api/twisted.enterprise.adbapi.ConnectionPool.html#__init__
-#
-#
-# Example SQLite configuration:
-#
-#database:
-#  name: sqlite3
-#  args:
-#    database: /path/to/homeserver.db
-#
-#
-# Example Postgres configuration:
-#
-#database:
-#  name: psycopg2
-#  txn_limit: 10000
-#  args:
-#    user: synapse_user
-#    password: secretpassword
-#    database: synapse
-#    host: localhost
-#    port: 5432
-#    cp_min: 5
-#    cp_max: 10
-#
-# For more information on using Synapse with Postgres,
-# see https://matrix-org.github.io/synapse/latest/postgres.html.
-#
 database:
   name: sqlite3
   args:
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 5b5c2f4fff..a3af35b7c4 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -18,7 +18,6 @@
 import email.utils
 import logging
 import os
-from enum import Enum
 from typing import Any
 
 import attr
@@ -53,7 +52,7 @@ LEGACY_TEMPLATE_DIR_WARNING = """
 This server's configuration file is using the deprecated 'template_dir' setting in the
 'email' section. Support for this setting has been deprecated and will be removed in a
 future version of Synapse. Server admins should instead use the new
-'custom_templates_directory' setting documented here:
+'custom_template_directory' setting documented here:
 https://matrix-org.github.io/synapse/latest/templates.html
 ---------------------------------------------------------------------------------------"""
 
@@ -86,14 +85,19 @@ class EmailConfig(Config):
         if email_config is None:
             email_config = {}
 
+        self.force_tls = email_config.get("force_tls", False)
         self.email_smtp_host = email_config.get("smtp_host", "localhost")
-        self.email_smtp_port = email_config.get("smtp_port", 25)
+        self.email_smtp_port = email_config.get(
+            "smtp_port", 465 if self.force_tls else 25
+        )
         self.email_smtp_user = email_config.get("smtp_user", None)
         self.email_smtp_pass = email_config.get("smtp_pass", None)
         self.require_transport_security = email_config.get(
             "require_transport_security", False
         )
         self.enable_smtp_tls = email_config.get("enable_tls", True)
+        if self.force_tls and not self.enable_smtp_tls:
+            raise ConfigError("email.force_tls requires email.enable_tls to be true")
         if self.require_transport_security and not self.enable_smtp_tls:
             raise ConfigError(
                 "email.require_transport_security requires email.enable_tls to be true"
@@ -131,41 +135,22 @@ class EmailConfig(Config):
 
         self.email_enable_notifs = email_config.get("enable_notifs", False)
 
-        self.threepid_behaviour_email = (
-            # Have Synapse handle the email sending if account_threepid_delegates.email
-            # is not defined
-            # msisdn is currently always remote while Synapse does not support any method of
-            # sending SMS messages
-            ThreepidBehaviour.REMOTE
-            if self.root.registration.account_threepid_delegate_email
-            else ThreepidBehaviour.LOCAL
-        )
-
         if config.get("trust_identity_server_for_password_resets"):
             raise ConfigError(
                 'The config option "trust_identity_server_for_password_resets" '
-                'has been replaced by "account_threepid_delegate". '
-                "Please consult the sample config at docs/sample_config.yaml for "
-                "details and update your config file."
+                "is no longer supported. Please remove it from the config file."
             )
 
-        self.local_threepid_handling_disabled_due_to_email_config = False
-        if (
-            self.threepid_behaviour_email == ThreepidBehaviour.LOCAL
-            and email_config == {}
-        ):
-            # We cannot warn the user this has happened here
-            # Instead do so when a user attempts to reset their password
-            self.local_threepid_handling_disabled_due_to_email_config = True
-
-            self.threepid_behaviour_email = ThreepidBehaviour.OFF
+        # If we have email config settings, assume that we can verify ownership of
+        # email addresses.
+        self.can_verify_email = email_config != {}
 
         # Get lifetime of a validation token in milliseconds
         self.email_validation_token_lifetime = self.parse_duration(
             email_config.get("validation_token_lifetime", "1h")
         )
 
-        if self.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+        if self.can_verify_email:
             missing = []
             if not self.email_notif_from:
                 missing.append("email.notif_from")
@@ -356,172 +341,3 @@ class EmailConfig(Config):
                     "Config option email.invite_client_location must be a http or https URL",
                     path=("email", "invite_client_location"),
                 )
-
-    def generate_config_section(self, **kwargs: Any) -> str:
-        return (
-            """\
-        # Configuration for sending emails from Synapse.
-        #
-        # Server admins can configure custom templates for email content. See
-        # https://matrix-org.github.io/synapse/latest/templates.html for more information.
-        #
-        email:
-          # The hostname of the outgoing SMTP server to use. Defaults to 'localhost'.
-          #
-          #smtp_host: mail.server
-
-          # The port on the mail server for outgoing SMTP. Defaults to 25.
-          #
-          #smtp_port: 587
-
-          # Username/password for authentication to the SMTP server. By default, no
-          # authentication is attempted.
-          #
-          #smtp_user: "exampleusername"
-          #smtp_pass: "examplepassword"
-
-          # Uncomment the following to require TLS transport security for SMTP.
-          # By default, Synapse will connect over plain text, and will then switch to
-          # TLS via STARTTLS *if the SMTP server supports it*. If this option is set,
-          # Synapse will refuse to connect unless the server supports STARTTLS.
-          #
-          #require_transport_security: true
-
-          # Uncomment the following to disable TLS for SMTP.
-          #
-          # By default, if the server supports TLS, it will be used, and the server
-          # must present a certificate that is valid for 'smtp_host'. If this option
-          # is set to false, TLS will not be used.
-          #
-          #enable_tls: false
-
-          # notif_from defines the "From" address to use when sending emails.
-          # It must be set if email sending is enabled.
-          #
-          # The placeholder '%%(app)s' will be replaced by the application name,
-          # which is normally 'app_name' (below), but may be overridden by the
-          # Matrix client application.
-          #
-          # Note that the placeholder must be written '%%(app)s', including the
-          # trailing 's'.
-          #
-          #notif_from: "Your Friendly %%(app)s homeserver <noreply@example.com>"
-
-          # app_name defines the default value for '%%(app)s' in notif_from and email
-          # subjects. It defaults to 'Matrix'.
-          #
-          #app_name: my_branded_matrix_server
-
-          # Uncomment the following to enable sending emails for messages that the user
-          # has missed. Disabled by default.
-          #
-          #enable_notifs: true
-
-          # Uncomment the following to disable automatic subscription to email
-          # notifications for new users. Enabled by default.
-          #
-          #notif_for_new_users: false
-
-          # Custom URL for client links within the email notifications. By default
-          # links will be based on "https://matrix.to".
-          #
-          # (This setting used to be called riot_base_url; the old name is still
-          # supported for backwards-compatibility but is now deprecated.)
-          #
-          #client_base_url: "http://localhost/riot"
-
-          # Configure the time that a validation email will expire after sending.
-          # Defaults to 1h.
-          #
-          #validation_token_lifetime: 15m
-
-          # The web client location to direct users to during an invite. This is passed
-          # to the identity server as the org.matrix.web_client_location key. Defaults
-          # to unset, giving no guidance to the identity server.
-          #
-          #invite_client_location: https://app.element.io
-
-          # Subjects to use when sending emails from Synapse.
-          #
-          # The placeholder '%%(app)s' will be replaced with the value of the 'app_name'
-          # setting above, or by a value dictated by the Matrix client application.
-          #
-          # If a subject isn't overridden in this configuration file, the value used as
-          # its example will be used.
-          #
-          #subjects:
-
-            # Subjects for notification emails.
-            #
-            # On top of the '%%(app)s' placeholder, these can use the following
-            # placeholders:
-            #
-            #   * '%%(person)s', which will be replaced by the display name of the user(s)
-            #      that sent the message(s), e.g. "Alice and Bob".
-            #   * '%%(room)s', which will be replaced by the name of the room the
-            #      message(s) have been sent to, e.g. "My super room".
-            #
-            # See the example provided for each setting to see which placeholder can be
-            # used and how to use them.
-            #
-            # Subject to use to notify about one message from one or more user(s) in a
-            # room which has a name.
-            #message_from_person_in_room: "%(message_from_person_in_room)s"
-            #
-            # Subject to use to notify about one message from one or more user(s) in a
-            # room which doesn't have a name.
-            #message_from_person: "%(message_from_person)s"
-            #
-            # Subject to use to notify about multiple messages from one or more users in
-            # a room which doesn't have a name.
-            #messages_from_person: "%(messages_from_person)s"
-            #
-            # Subject to use to notify about multiple messages in a room which has a
-            # name.
-            #messages_in_room: "%(messages_in_room)s"
-            #
-            # Subject to use to notify about multiple messages in multiple rooms.
-            #messages_in_room_and_others: "%(messages_in_room_and_others)s"
-            #
-            # Subject to use to notify about multiple messages from multiple persons in
-            # multiple rooms. This is similar to the setting above except it's used when
-            # the room in which the notification was triggered has no name.
-            #messages_from_person_and_others: "%(messages_from_person_and_others)s"
-            #
-            # Subject to use to notify about an invite to a room which has a name.
-            #invite_from_person_to_room: "%(invite_from_person_to_room)s"
-            #
-            # Subject to use to notify about an invite to a room which doesn't have a
-            # name.
-            #invite_from_person: "%(invite_from_person)s"
-
-            # Subject for emails related to account administration.
-            #
-            # On top of the '%%(app)s' placeholder, these one can use the
-            # '%%(server_name)s' placeholder, which will be replaced by the value of the
-            # 'server_name' setting in your Synapse configuration.
-            #
-            # Subject to use when sending a password reset email.
-            #password_reset: "%(password_reset)s"
-            #
-            # Subject to use when sending a verification email to assert an address's
-            # ownership.
-            #email_validation: "%(email_validation)s"
-        """
-            % DEFAULT_SUBJECTS
-        )
-
-
-class ThreepidBehaviour(Enum):
-    """
-    Enum to define the behaviour of Synapse with regards to when it contacts an identity
-    server for 3pid registration and password resets
-
-    REMOTE = use an external server to send tokens
-    LOCAL = send tokens ourselves
-    OFF = disable registration via 3pid and password resets
-    """
-
-    REMOTE = "remote"
-    LOCAL = "local"
-    OFF = "off"
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index f2dfd49b07..702b81e636 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -32,9 +32,6 @@ class ExperimentalConfig(Config):
         # MSC2716 (importing historical messages)
         self.msc2716_enabled: bool = experimental.get("msc2716_enabled", False)
 
-        # MSC2285 (private read receipts)
-        self.msc2285_enabled: bool = experimental.get("msc2285_enabled", False)
-
         # MSC3244 (room version capabilities)
         self.msc3244_enabled: bool = experimental.get("msc3244_enabled", True)
 
@@ -74,6 +71,9 @@ class ExperimentalConfig(Config):
         self.msc3720_enabled: bool = experimental.get("msc3720_enabled", False)
 
         # MSC2654: Unread counts
+        #
+        # Note that enabling this will result in an incorrect unread count for
+        # previously calculated push actions.
         self.msc2654_enabled: bool = experimental.get("msc2654_enabled", False)
 
         # MSC2815 (allow room moderators to view redacted event content)
@@ -84,3 +84,12 @@ class ExperimentalConfig(Config):
 
         # MSC3772: A push rule for mutual relations.
         self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False)
+
+        # MSC3715: dir param on /relations.
+        self.msc3715_enabled: bool = experimental.get("msc3715_enabled", False)
+
+        # MSC3848: Introduce errcodes for specific event sending failures
+        self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False)
+
+        # MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices.
+        self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False)
diff --git a/synapse/config/federation.py b/synapse/config/federation.py
index f83f93c0ef..336fca578a 100644
--- a/synapse/config/federation.py
+++ b/synapse/config/federation.py
@@ -49,44 +49,5 @@ class FederationConfig(Config):
             "allow_device_name_lookup_over_federation", False
         )
 
-    def generate_config_section(self, **kwargs: Any) -> str:
-        return """\
-        ## Federation ##
-
-        # Restrict federation to the following whitelist of domains.
-        # N.B. we recommend also firewalling your federation listener to limit
-        # inbound federation traffic as early as possible, rather than relying
-        # purely on this application-layer restriction.  If not specified, the
-        # default is to whitelist everything.
-        #
-        #federation_domain_whitelist:
-        #  - lon.example.com
-        #  - nyc.example.com
-        #  - syd.example.com
-
-        # Report prometheus metrics on the age of PDUs being sent to and received from
-        # the following domains. This can be used to give an idea of "delay" on inbound
-        # and outbound federation, though be aware that any delay can be due to problems
-        # at either end or with the intermediate network.
-        #
-        # By default, no domains are monitored in this way.
-        #
-        #federation_metrics_domains:
-        #  - matrix.org
-        #  - example.com
-
-        # Uncomment to disable profile lookup over federation. By default, the
-        # Federation API allows other homeservers to obtain profile data of any user
-        # on this homeserver. Defaults to 'true'.
-        #
-        #allow_profile_lookup_over_federation: false
-
-        # Uncomment to allow device display name lookup over federation. By default, the
-        # Federation API prevents other homeservers from obtaining the display names of
-        # user devices on this homeserver. Defaults to 'false'.
-        #
-        #allow_device_name_lookup_over_federation: true
-        """
-
 
 _METRICS_FOR_DOMAINS_SCHEMA = {"type": "array", "items": {"type": "string"}}
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/config/groups.py
index 3826b87dec..baa051fdd4 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/config/groups.py
@@ -1,5 +1,4 @@
-# Copyright 2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2017 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,10 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
+from typing import Any
 
-from ._base import BaseSlavedStore
+from synapse.types import JsonDict
 
+from ._base import Config
 
-class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
-    pass
+
+class GroupsConfig(Config):
+    section = "groups"
+
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
+        self.enable_group_creation = config.get("enable_group_creation", False)
+        self.group_creation_prefix = config.get("group_creation_prefix", "")
diff --git a/synapse/config/jwt.py b/synapse/config/jwt.py
index 2a756d1a7c..a973bb5080 100644
--- a/synapse/config/jwt.py
+++ b/synapse/config/jwt.py
@@ -15,14 +15,9 @@
 from typing import Any
 
 from synapse.types import JsonDict
+from synapse.util.check_dependencies import check_requirements
 
-from ._base import Config, ConfigError
-
-MISSING_JWT = """Missing jwt library. This is required for jwt login.
-
-    Install by running:
-        pip install pyjwt
-    """
+from ._base import Config
 
 
 class JWTConfig(Config):
@@ -41,13 +36,7 @@ class JWTConfig(Config):
             # that the claims exist on the JWT.
             self.jwt_issuer = jwt_config.get("issuer")
             self.jwt_audiences = jwt_config.get("audiences")
-
-            try:
-                import jwt
-
-                jwt  # To stop unused lint.
-            except ImportError:
-                raise ConfigError(MISSING_JWT)
+            check_requirements("jwt")
         else:
             self.jwt_enabled = False
             self.jwt_secret = None
@@ -55,67 +44,3 @@ class JWTConfig(Config):
             self.jwt_subject_claim = None
             self.jwt_issuer = None
             self.jwt_audiences = None
-
-    def generate_config_section(self, **kwargs: Any) -> str:
-        return """\
-        # JSON web token integration. The following settings can be used to make
-        # Synapse JSON web tokens for authentication, instead of its internal
-        # password database.
-        #
-        # Each JSON Web Token needs to contain a "sub" (subject) claim, which is
-        # used as the localpart of the mxid.
-        #
-        # Additionally, the expiration time ("exp"), not before time ("nbf"),
-        # and issued at ("iat") claims are validated if present.
-        #
-        # Note that this is a non-standard login type and client support is
-        # expected to be non-existent.
-        #
-        # See https://matrix-org.github.io/synapse/latest/jwt.html.
-        #
-        #jwt_config:
-            # Uncomment the following to enable authorization using JSON web
-            # tokens. Defaults to false.
-            #
-            #enabled: true
-
-            # This is either the private shared secret or the public key used to
-            # decode the contents of the JSON web token.
-            #
-            # Required if 'enabled' is true.
-            #
-            #secret: "provided-by-your-issuer"
-
-            # The algorithm used to sign the JSON web token.
-            #
-            # Supported algorithms are listed at
-            # https://pyjwt.readthedocs.io/en/latest/algorithms.html
-            #
-            # Required if 'enabled' is true.
-            #
-            #algorithm: "provided-by-your-issuer"
-
-            # Name of the claim containing a unique identifier for the user.
-            #
-            # Optional, defaults to `sub`.
-            #
-            #subject_claim: "sub"
-
-            # The issuer to validate the "iss" claim against.
-            #
-            # Optional, if provided the "iss" claim will be required and
-            # validated for all JSON web tokens.
-            #
-            #issuer: "provided-by-your-issuer"
-
-            # A list of audiences to validate the "aud" claim against.
-            #
-            # Optional, if provided the "aud" claim will be required and
-            # validated for all JSON web tokens.
-            #
-            # Note that if the "aud" claim is included in a JSON web token then
-            # validation will fail without configuring audiences.
-            #
-            #audiences:
-            #    - "provided-by-your-issuer"
-        """
diff --git a/synapse/config/key.py b/synapse/config/key.py
index ada65f6dd6..cc75efdf8f 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -159,16 +159,18 @@ class KeyConfig(Config):
             )
         )
 
-        self.macaroon_secret_key = config.get(
+        macaroon_secret_key: Optional[str] = config.get(
             "macaroon_secret_key", self.root.registration.registration_shared_secret
         )
 
-        if not self.macaroon_secret_key:
+        if not macaroon_secret_key:
             # Unfortunately, there are people out there that don't have this
             # set. Lets just be "nice" and derive one from their secret key.
             logger.warning("Config is missing macaroon_secret_key")
             seed = bytes(self.signing_key[0])
             self.macaroon_secret_key = hashlib.sha256(seed).digest()
+        else:
+            self.macaroon_secret_key = macaroon_secret_key.encode("utf-8")
 
         # a secret which is used to calculate HMACs for form values, to stop
         # falsification of values
@@ -182,111 +184,22 @@ class KeyConfig(Config):
         **kwargs: Any,
     ) -> str:
         base_key_name = os.path.join(config_dir_path, server_name)
+        macaroon_secret_key = ""
+        form_secret = ""
 
         if generate_secrets:
             macaroon_secret_key = 'macaroon_secret_key: "%s"' % (
                 random_string_with_symbols(50),
             )
             form_secret = 'form_secret: "%s"' % random_string_with_symbols(50)
-        else:
-            macaroon_secret_key = "#macaroon_secret_key: <PRIVATE STRING>"
-            form_secret = "#form_secret: <PRIVATE STRING>"
 
         return (
             """\
-        # a secret which is used to sign access tokens. If none is specified,
-        # the registration_shared_secret is used, if one is given; otherwise,
-        # a secret key is derived from the signing key.
-        #
         %(macaroon_secret_key)s
-
-        # a secret which is used to calculate HMACs for form values, to stop
-        # falsification of values. Must be specified for the User Consent
-        # forms to work.
-        #
         %(form_secret)s
-
-        ## Signing Keys ##
-
-        # Path to the signing key to sign messages with
-        #
         signing_key_path: "%(base_key_name)s.signing.key"
-
-        # The keys that the server used to sign messages with but won't use
-        # to sign new messages.
-        #
-        old_signing_keys:
-          # For each key, `key` should be the base64-encoded public key, and
-          # `expired_ts`should be the time (in milliseconds since the unix epoch) that
-          # it was last used.
-          #
-          # It is possible to build an entry from an old signing.key file using the
-          # `export_signing_key` script which is provided with synapse.
-          #
-          # For example:
-          #
-          #"ed25519:id": { key: "base64string", expired_ts: 123456789123 }
-
-        # How long key response published by this server is valid for.
-        # Used to set the valid_until_ts in /key/v2 APIs.
-        # Determines how quickly servers will query to check which keys
-        # are still valid.
-        #
-        #key_refresh_interval: 1d
-
-        # The trusted servers to download signing keys from.
-        #
-        # When we need to fetch a signing key, each server is tried in parallel.
-        #
-        # Normally, the connection to the key server is validated via TLS certificates.
-        # Additional security can be provided by configuring a `verify key`, which
-        # will make synapse check that the response is signed by that key.
-        #
-        # This setting supercedes an older setting named `perspectives`. The old format
-        # is still supported for backwards-compatibility, but it is deprecated.
-        #
-        # 'trusted_key_servers' defaults to matrix.org, but using it will generate a
-        # warning on start-up. To suppress this warning, set
-        # 'suppress_key_server_warning' to true.
-        #
-        # Options for each entry in the list include:
-        #
-        #    server_name: the name of the server. required.
-        #
-        #    verify_keys: an optional map from key id to base64-encoded public key.
-        #       If specified, we will check that the response is signed by at least
-        #       one of the given keys.
-        #
-        #    accept_keys_insecurely: a boolean. Normally, if `verify_keys` is unset,
-        #       and federation_verify_certificates is not `true`, synapse will refuse
-        #       to start, because this would allow anyone who can spoof DNS responses
-        #       to masquerade as the trusted key server. If you know what you are doing
-        #       and are sure that your network environment provides a secure connection
-        #       to the key server, you can set this to `true` to override this
-        #       behaviour.
-        #
-        # An example configuration might look like:
-        #
-        #trusted_key_servers:
-        #  - server_name: "my_trusted_server.example.com"
-        #    verify_keys:
-        #      "ed25519:auto": "abcdefghijklmnopqrstuvwxyzabcdefghijklmopqr"
-        #  - server_name: "my_other_trusted_server.example.com"
-        #
         trusted_key_servers:
           - server_name: "matrix.org"
-
-        # Uncomment the following to disable the warning that is emitted when the
-        # trusted_key_servers include 'matrix.org'. See above.
-        #
-        #suppress_key_server_warning: true
-
-        # The signing keys to use when acting as a trusted key server. If not specified
-        # defaults to the server signing key.
-        #
-        # Can contain multiple keys, one per line.
-        #
-        #key_server_signing_keys_path: "key_server_signing_keys.key"
         """
             % locals()
         )
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 82a5b5fa12..6c1f78f8df 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -153,11 +153,6 @@ class LoggingConfig(Config):
         log_config = os.path.join(config_dir_path, server_name + ".log.config")
         return (
             """\
-        ## Logging ##
-
-        # A yaml python logging config file as described by
-        # https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema
-        #
         log_config: "%(log_config)s"
         """
             % locals()
diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py
index aa360a417c..f3134834e5 100644
--- a/synapse/config/metrics.py
+++ b/synapse/config/metrics.py
@@ -18,7 +18,7 @@ from typing import Any, Optional
 import attr
 
 from synapse.types import JsonDict
-from synapse.util.check_dependencies import DependencyException, check_requirements
+from synapse.util.check_dependencies import check_requirements
 
 from ._base import Config, ConfigError
 
@@ -42,6 +42,35 @@ class MetricsConfig(Config):
 
     def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         self.enable_metrics = config.get("enable_metrics", False)
+
+        """
+        ### `enable_legacy_metrics` (experimental)
+
+        **Experimental: this option may be removed or have its behaviour
+        changed at any time, with no notice.**
+
+        Set to `true` to publish both legacy and non-legacy Prometheus metric names,
+        or to `false` to only publish non-legacy Prometheus metric names.
+        Defaults to `true`. Has no effect if `enable_metrics` is `false`.
+
+        Legacy metric names include:
+        - metrics containing colons in the name, such as `synapse_util_caches_response_cache:hits`, because colons are supposed to be reserved for user-defined recording rules;
+        - counters that don't end with the `_total` suffix, such as `synapse_federation_client_sent_edus`, therefore not adhering to the OpenMetrics standard.
+
+        These legacy metric names are unconventional and not compliant with OpenMetrics standards.
+        They are included for backwards compatibility.
+
+        Example configuration:
+        ```yaml
+        enable_legacy_metrics: false
+        ```
+
+        See https://github.com/matrix-org/synapse/issues/11106 for context.
+
+        *Since v1.67.0.*
+        """
+        self.enable_legacy_metrics = config.get("enable_legacy_metrics", True)
+
         self.report_stats = config.get("report_stats", None)
         self.report_stats_endpoint = config.get(
             "report_stats_endpoint", "https://matrix.org/report-usage-stats/push"
@@ -57,12 +86,7 @@ class MetricsConfig(Config):
 
         self.sentry_enabled = "sentry" in config
         if self.sentry_enabled:
-            try:
-                check_requirements("sentry")
-            except DependencyException as e:
-                raise ConfigError(
-                    e.message  # noqa: B306, DependencyException.message is a property
-                )
+            check_requirements("sentry")
 
             self.sentry_dsn = config["sentry"].get("dsn")
             if not self.sentry_dsn:
@@ -73,46 +97,8 @@ class MetricsConfig(Config):
     def generate_config_section(
         self, report_stats: Optional[bool] = None, **kwargs: Any
     ) -> str:
-        res = """\
-        ## Metrics ###
-
-        # Enable collection and rendering of performance metrics
-        #
-        #enable_metrics: false
-
-        # Enable sentry integration
-        # NOTE: While attempts are made to ensure that the logs don't contain
-        # any sensitive information, this cannot be guaranteed. By enabling
-        # this option the sentry server may therefore receive sensitive
-        # information, and it in turn may then diseminate sensitive information
-        # through insecure notification channels if so configured.
-        #
-        #sentry:
-        #    dsn: "..."
-
-        # Flags to enable Prometheus metrics which are not suitable to be
-        # enabled by default, either for performance reasons or limited use.
-        #
-        metrics_flags:
-            # Publish synapse_federation_known_servers, a gauge of the number of
-            # servers this homeserver knows about, including itself. May cause
-            # performance problems on large homeservers.
-            #
-            #known_servers: true
-
-        # Whether or not to report anonymized homeserver usage statistics.
-        #
-        """
-
-        if report_stats is None:
-            res += "#report_stats: true|false\n"
+        if report_stats is not None:
+            res = "report_stats: %s\n" % ("true" if report_stats else "false")
         else:
-            res += "report_stats: %s\n" % ("true" if report_stats else "false")
-
-        res += """
-        # The endpoint to report the anonymized homeserver usage statistics to.
-        # Defaults to https://matrix.org/report-usage-stats/push
-        #
-        #report_stats_endpoint: https://example.com/report-usage-stats/push
-        """
+            res = "\n"
         return res
diff --git a/synapse/config/modules.py b/synapse/config/modules.py
index 0915014f7d..903637be8e 100644
--- a/synapse/config/modules.py
+++ b/synapse/config/modules.py
@@ -31,20 +31,3 @@ class ModulesConfig(Config):
                 raise ConfigError("expected a mapping", config_path)
 
             self.loaded_modules.append(load_module(module, config_path))
-
-    def generate_config_section(self, **kwargs: Any) -> str:
-        return """
-            ## Modules ##
-
-            # Server admins can expand Synapse's functionality with external modules.
-            #
-            # See https://matrix-org.github.io/synapse/latest/modules/index.html for more
-            # documentation on how to configure or create custom modules for Synapse.
-            #
-            modules:
-              #- module: my_super_module.MySuperClass
-              #  config:
-              #    do_thing: true
-              #- module: my_other_super_module.SomeClass
-              #  config: {}
-            """
diff --git a/synapse/config/oembed.py b/synapse/config/oembed.py
index e9edea0731..0d32aba70a 100644
--- a/synapse/config/oembed.py
+++ b/synapse/config/oembed.py
@@ -143,29 +143,6 @@ class OembedConfig(Config):
         )
         return re.compile(pattern)
 
-    def generate_config_section(self, **kwargs: Any) -> str:
-        return """\
-        # oEmbed allows for easier embedding content from a website. It can be
-        # used for generating URLs previews of services which support it.
-        #
-        oembed:
-          # A default list of oEmbed providers is included with Synapse.
-          #
-          # Uncomment the following to disable using these default oEmbed URLs.
-          # Defaults to 'false'.
-          #
-          #disable_default_providers: true
-
-          # Additional files with oEmbed configuration (each should be in the
-          # form of providers.json).
-          #
-          # By default, this list is empty (so only the default providers.json
-          # is used).
-          #
-          #additional_providers:
-          #  - oembed/my_providers.json
-        """
-
 
 _OEMBED_PROVIDER_SCHEMA = {
     "type": "array",
diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py
index b9c40522d8..5418a332da 100644
--- a/synapse/config/oidc.py
+++ b/synapse/config/oidc.py
@@ -24,7 +24,7 @@ from synapse.types import JsonDict
 from synapse.util.module_loader import load_module
 from synapse.util.stringutils import parse_and_validate_mxc_uri
 
-from ..util.check_dependencies import DependencyException, check_requirements
+from ..util.check_dependencies import check_requirements
 from ._base import Config, ConfigError, read_file
 
 DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc.JinjaOidcMappingProvider"
@@ -41,12 +41,7 @@ class OIDCConfig(Config):
         if not self.oidc_providers:
             return
 
-        try:
-            check_requirements("oidc")
-        except DependencyException as e:
-            raise ConfigError(
-                e.message  # noqa: B306, DependencyException.message is a property
-            ) from e
+        check_requirements("oidc")
 
         # check we don't have any duplicate idp_ids now. (The SSO handler will also
         # check for duplicates when the REST listeners get registered, but that happens
@@ -66,203 +61,6 @@ class OIDCConfig(Config):
         # OIDC is enabled if we have a provider
         return bool(self.oidc_providers)
 
-    def generate_config_section(self, **kwargs: Any) -> str:
-        return """\
-        # List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration
-        # and login.
-        #
-        # Options for each entry include:
-        #
-        #   idp_id: a unique identifier for this identity provider. Used internally
-        #       by Synapse; should be a single word such as 'github'.
-        #
-        #       Note that, if this is changed, users authenticating via that provider
-        #       will no longer be recognised as the same user!
-        #
-        #       (Use "oidc" here if you are migrating from an old "oidc_config"
-        #       configuration.)
-        #
-        #   idp_name: A user-facing name for this identity provider, which is used to
-        #       offer the user a choice of login mechanisms.
-        #
-        #   idp_icon: An optional icon for this identity provider, which is presented
-        #       by clients and Synapse's own IdP picker page. If given, must be an
-        #       MXC URI of the format mxc://<server-name>/<media-id>. (An easy way to
-        #       obtain such an MXC URI is to upload an image to an (unencrypted) room
-        #       and then copy the "url" from the source of the event.)
-        #
-        #   idp_brand: An optional brand for this identity provider, allowing clients
-        #       to style the login flow according to the identity provider in question.
-        #       See the spec for possible options here.
-        #
-        #   discover: set to 'false' to disable the use of the OIDC discovery mechanism
-        #       to discover endpoints. Defaults to true.
-        #
-        #   issuer: Required. The OIDC issuer. Used to validate tokens and (if discovery
-        #       is enabled) to discover the provider's endpoints.
-        #
-        #   client_id: Required. oauth2 client id to use.
-        #
-        #   client_secret: oauth2 client secret to use. May be omitted if
-        #        client_secret_jwt_key is given, or if client_auth_method is 'none'.
-        #
-        #   client_secret_jwt_key: Alternative to client_secret: details of a key used
-        #      to create a JSON Web Token to be used as an OAuth2 client secret. If
-        #      given, must be a dictionary with the following properties:
-        #
-        #          key: a pem-encoded signing key. Must be a suitable key for the
-        #              algorithm specified. Required unless 'key_file' is given.
-        #
-        #          key_file: the path to file containing a pem-encoded signing key file.
-        #              Required unless 'key' is given.
-        #
-        #          jwt_header: a dictionary giving properties to include in the JWT
-        #              header. Must include the key 'alg', giving the algorithm used to
-        #              sign the JWT, such as "ES256", using the JWA identifiers in
-        #              RFC7518.
-        #
-        #          jwt_payload: an optional dictionary giving properties to include in
-        #              the JWT payload. Normally this should include an 'iss' key.
-        #
-        #   client_auth_method: auth method to use when exchanging the token. Valid
-        #       values are 'client_secret_basic' (default), 'client_secret_post' and
-        #       'none'.
-        #
-        #   scopes: list of scopes to request. This should normally include the "openid"
-        #       scope. Defaults to ["openid"].
-        #
-        #   authorization_endpoint: the oauth2 authorization endpoint. Required if
-        #       provider discovery is disabled.
-        #
-        #   token_endpoint: the oauth2 token endpoint. Required if provider discovery is
-        #       disabled.
-        #
-        #   userinfo_endpoint: the OIDC userinfo endpoint. Required if discovery is
-        #       disabled and the 'openid' scope is not requested.
-        #
-        #   jwks_uri: URI where to fetch the JWKS. Required if discovery is disabled and
-        #       the 'openid' scope is used.
-        #
-        #   skip_verification: set to 'true' to skip metadata verification. Use this if
-        #       you are connecting to a provider that is not OpenID Connect compliant.
-        #       Defaults to false. Avoid this in production.
-        #
-        #   user_profile_method: Whether to fetch the user profile from the userinfo
-        #       endpoint, or to rely on the data returned in the id_token from the
-        #       token_endpoint.
-        #
-        #       Valid values are: 'auto' or 'userinfo_endpoint'.
-        #
-        #       Defaults to 'auto', which uses the userinfo endpoint if 'openid' is
-        #       not included in 'scopes'. Set to 'userinfo_endpoint' to always use the
-        #       userinfo endpoint.
-        #
-        #   allow_existing_users: set to 'true' to allow a user logging in via OIDC to
-        #       match a pre-existing account instead of failing. This could be used if
-        #       switching from password logins to OIDC. Defaults to false.
-        #
-        #   user_mapping_provider: Configuration for how attributes returned from a OIDC
-        #       provider are mapped onto a matrix user. This setting has the following
-        #       sub-properties:
-        #
-        #       module: The class name of a custom mapping module. Default is
-        #           {mapping_provider!r}.
-        #           See https://matrix-org.github.io/synapse/latest/sso_mapping_providers.html#openid-mapping-providers
-        #           for information on implementing a custom mapping provider.
-        #
-        #       config: Configuration for the mapping provider module. This section will
-        #           be passed as a Python dictionary to the user mapping provider
-        #           module's `parse_config` method.
-        #
-        #           For the default provider, the following settings are available:
-        #
-        #             subject_claim: name of the claim containing a unique identifier
-        #                 for the user. Defaults to 'sub', which OpenID Connect
-        #                 compliant providers should provide.
-        #
-        #             localpart_template: Jinja2 template for the localpart of the MXID.
-        #                 If this is not set, the user will be prompted to choose their
-        #                 own username (see the documentation for the
-        #                 'sso_auth_account_details.html' template). This template can
-        #                 use the 'localpart_from_email' filter.
-        #
-        #             confirm_localpart: Whether to prompt the user to validate (or
-        #                 change) the generated localpart (see the documentation for the
-        #                 'sso_auth_account_details.html' template), instead of
-        #                 registering the account right away.
-        #
-        #             display_name_template: Jinja2 template for the display name to set
-        #                 on first login. If unset, no displayname will be set.
-        #
-        #             email_template: Jinja2 template for the email address of the user.
-        #                 If unset, no email address will be added to the account.
-        #
-        #             extra_attributes: a map of Jinja2 templates for extra attributes
-        #                 to send back to the client during login.
-        #                 Note that these are non-standard and clients will ignore them
-        #                 without modifications.
-        #
-        #           When rendering, the Jinja2 templates are given a 'user' variable,
-        #           which is set to the claims returned by the UserInfo Endpoint and/or
-        #           in the ID Token.
-        #
-        #   It is possible to configure Synapse to only allow logins if certain attributes
-        #   match particular values in the OIDC userinfo. The requirements can be listed under
-        #   `attribute_requirements` as shown below. All of the listed attributes must
-        #   match for the login to be permitted. Additional attributes can be added to
-        #   userinfo by expanding the `scopes` section of the OIDC config to retrieve
-        #   additional information from the OIDC provider.
-        #
-        #   If the OIDC claim is a list, then the attribute must match any value in the list.
-        #   Otherwise, it must exactly match the value of the claim. Using the example
-        #   below, the `family_name` claim MUST be "Stephensson", but the `groups`
-        #   claim MUST contain "admin".
-        #
-        #   attribute_requirements:
-        #     - attribute: family_name
-        #       value: "Stephensson"
-        #     - attribute: groups
-        #       value: "admin"
-        #
-        # See https://matrix-org.github.io/synapse/latest/openid.html
-        # for information on how to configure these options.
-        #
-        # For backwards compatibility, it is also possible to configure a single OIDC
-        # provider via an 'oidc_config' setting. This is now deprecated and admins are
-        # advised to migrate to the 'oidc_providers' format. (When doing that migration,
-        # use 'oidc' for the idp_id to ensure that existing users continue to be
-        # recognised.)
-        #
-        oidc_providers:
-          # Generic example
-          #
-          #- idp_id: my_idp
-          #  idp_name: "My OpenID provider"
-          #  idp_icon: "mxc://example.com/mediaid"
-          #  discover: false
-          #  issuer: "https://accounts.example.com/"
-          #  client_id: "provided-by-your-issuer"
-          #  client_secret: "provided-by-your-issuer"
-          #  client_auth_method: client_secret_post
-          #  scopes: ["openid", "profile"]
-          #  authorization_endpoint: "https://accounts.example.com/oauth2/auth"
-          #  token_endpoint: "https://accounts.example.com/oauth2/token"
-          #  userinfo_endpoint: "https://accounts.example.com/userinfo"
-          #  jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
-          #  skip_verification: true
-          #  user_mapping_provider:
-          #    config:
-          #      subject_claim: "id"
-          #      localpart_template: "{{{{ user.login }}}}"
-          #      display_name_template: "{{{{ user.name }}}}"
-          #      email_template: "{{{{ user.email }}}}"
-          #  attribute_requirements:
-          #    - attribute: userGroup
-          #      value: "synapseUsers"
-        """.format(
-            mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
-        )
-
 
 # jsonschema definition of the configuration settings for an oidc identity provider
 OIDC_PROVIDER_CONFIG_SCHEMA = {
@@ -343,7 +141,6 @@ OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA = {
     "allOf": [OIDC_PROVIDER_CONFIG_SCHEMA, {"required": ["idp_id", "idp_name"]}]
 }
 
-
 # the `oidc_providers` list can either be None (as it is in the default config), or
 # a list of provider configs, each of which requires an explicit ID and name.
 OIDC_PROVIDER_LIST_SCHEMA = {
diff --git a/synapse/config/push.py b/synapse/config/push.py
index 2e796d1c46..979b128eae 100644
--- a/synapse/config/push.py
+++ b/synapse/config/push.py
@@ -49,36 +49,3 @@ class PushConfig(Config):
                 "please set push.include_content instead"
             )
             self.push_include_content = not redact_content
-
-    def generate_config_section(self, **kwargs: Any) -> str:
-        return """
-        ## Push ##
-
-        push:
-          # Clients requesting push notifications can either have the body of
-          # the message sent in the notification poke along with other details
-          # like the sender, or just the event ID and room ID (`event_id_only`).
-          # If clients choose the former, this option controls whether the
-          # notification request includes the content of the event (other details
-          # like the sender are still included). For `event_id_only` push, it
-          # has no effect.
-          #
-          # For modern android devices the notification content will still appear
-          # because it is loaded by the app. iPhone, however will send a
-          # notification saying only that a message arrived and who it came from.
-          #
-          # The default value is "true" to include message details. Uncomment to only
-          # include the event ID and room ID in push notification payloads.
-          #
-          #include_content: false
-
-          # When a push notification is received, an unread count is also sent.
-          # This number can either be calculated as the number of unread messages
-          # for the user, or the number of *rooms* the user has unread messages in.
-          #
-          # The default value is "true", meaning push clients will see the number of
-          # rooms with unread messages in them. Uncomment to instead send the number
-          # of unread messages.
-          #
-          #group_unread_count_by_room: false
-        """
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 0587f5c10f..1ed001e105 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -21,7 +21,7 @@ from synapse.types import JsonDict
 from ._base import Config
 
 
-class RateLimitConfig:
+class RatelimitSettings:
     def __init__(
         self,
         config: Dict[str, float],
@@ -34,7 +34,7 @@ class RateLimitConfig:
 
 
 @attr.s(auto_attribs=True)
-class FederationRateLimitConfig:
+class FederationRatelimitSettings:
     window_size: int = 1000
     sleep_limit: int = 10
     sleep_delay: int = 500
@@ -50,11 +50,11 @@ class RatelimitConfig(Config):
         # Load the new-style messages config if it exists. Otherwise fall back
         # to the old method.
         if "rc_message" in config:
-            self.rc_message = RateLimitConfig(
+            self.rc_message = RatelimitSettings(
                 config["rc_message"], defaults={"per_second": 0.2, "burst_count": 10.0}
             )
         else:
-            self.rc_message = RateLimitConfig(
+            self.rc_message = RatelimitSettings(
                 {
                     "per_second": config.get("rc_messages_per_second", 0.2),
                     "burst_count": config.get("rc_message_burst_count", 10.0),
@@ -64,9 +64,9 @@ class RatelimitConfig(Config):
         # Load the new-style federation config, if it exists. Otherwise, fall
         # back to the old method.
         if "rc_federation" in config:
-            self.rc_federation = FederationRateLimitConfig(**config["rc_federation"])
+            self.rc_federation = FederationRatelimitSettings(**config["rc_federation"])
         else:
-            self.rc_federation = FederationRateLimitConfig(
+            self.rc_federation = FederationRatelimitSettings(
                 **{
                     k: v
                     for k, v in {
@@ -80,17 +80,17 @@ class RatelimitConfig(Config):
                 }
             )
 
-        self.rc_registration = RateLimitConfig(config.get("rc_registration", {}))
+        self.rc_registration = RatelimitSettings(config.get("rc_registration", {}))
 
-        self.rc_registration_token_validity = RateLimitConfig(
+        self.rc_registration_token_validity = RatelimitSettings(
             config.get("rc_registration_token_validity", {}),
             defaults={"per_second": 0.1, "burst_count": 5},
         )
 
         rc_login_config = config.get("rc_login", {})
-        self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {}))
-        self.rc_login_account = RateLimitConfig(rc_login_config.get("account", {}))
-        self.rc_login_failed_attempts = RateLimitConfig(
+        self.rc_login_address = RatelimitSettings(rc_login_config.get("address", {}))
+        self.rc_login_account = RatelimitSettings(rc_login_config.get("account", {}))
+        self.rc_login_failed_attempts = RatelimitSettings(
             rc_login_config.get("failed_attempts", {})
         )
 
@@ -101,167 +101,57 @@ class RatelimitConfig(Config):
         rc_admin_redaction = config.get("rc_admin_redaction")
         self.rc_admin_redaction = None
         if rc_admin_redaction:
-            self.rc_admin_redaction = RateLimitConfig(rc_admin_redaction)
+            self.rc_admin_redaction = RatelimitSettings(rc_admin_redaction)
 
-        self.rc_joins_local = RateLimitConfig(
+        self.rc_joins_local = RatelimitSettings(
             config.get("rc_joins", {}).get("local", {}),
             defaults={"per_second": 0.1, "burst_count": 10},
         )
-        self.rc_joins_remote = RateLimitConfig(
+        self.rc_joins_remote = RatelimitSettings(
             config.get("rc_joins", {}).get("remote", {}),
             defaults={"per_second": 0.01, "burst_count": 10},
         )
 
+        # Track the rate of joins to a given room. If there are too many, temporarily
+        # prevent local joins and remote joins via this server.
+        self.rc_joins_per_room = RatelimitSettings(
+            config.get("rc_joins_per_room", {}),
+            defaults={"per_second": 1, "burst_count": 10},
+        )
+
         # Ratelimit cross-user key requests:
         # * For local requests this is keyed by the sending device.
         # * For requests received over federation this is keyed by the origin.
         #
         # Note that this isn't exposed in the configuration as it is obscure.
-        self.rc_key_requests = RateLimitConfig(
+        self.rc_key_requests = RatelimitSettings(
             config.get("rc_key_requests", {}),
             defaults={"per_second": 20, "burst_count": 100},
         )
 
-        self.rc_3pid_validation = RateLimitConfig(
+        self.rc_3pid_validation = RatelimitSettings(
             config.get("rc_3pid_validation") or {},
             defaults={"per_second": 0.003, "burst_count": 5},
         )
 
-        self.rc_invites_per_room = RateLimitConfig(
+        self.rc_invites_per_room = RatelimitSettings(
             config.get("rc_invites", {}).get("per_room", {}),
             defaults={"per_second": 0.3, "burst_count": 10},
         )
-        self.rc_invites_per_user = RateLimitConfig(
+        self.rc_invites_per_user = RatelimitSettings(
             config.get("rc_invites", {}).get("per_user", {}),
             defaults={"per_second": 0.003, "burst_count": 5},
         )
 
-        self.rc_third_party_invite = RateLimitConfig(
+        self.rc_invites_per_issuer = RatelimitSettings(
+            config.get("rc_invites", {}).get("per_issuer", {}),
+            defaults={"per_second": 0.3, "burst_count": 10},
+        )
+
+        self.rc_third_party_invite = RatelimitSettings(
             config.get("rc_third_party_invite", {}),
             defaults={
                 "per_second": self.rc_message.per_second,
                 "burst_count": self.rc_message.burst_count,
             },
         )
-
-    def generate_config_section(self, **kwargs: Any) -> str:
-        return """\
-        ## Ratelimiting ##
-
-        # Ratelimiting settings for client actions (registration, login, messaging).
-        #
-        # Each ratelimiting configuration is made of two parameters:
-        #   - per_second: number of requests a client can send per second.
-        #   - burst_count: number of requests a client can send before being throttled.
-        #
-        # Synapse currently uses the following configurations:
-        #   - one for messages that ratelimits sending based on the account the client
-        #     is using
-        #   - one for registration that ratelimits registration requests based on the
-        #     client's IP address.
-        #   - one for checking the validity of registration tokens that ratelimits
-        #     requests based on the client's IP address.
-        #   - one for login that ratelimits login requests based on the client's IP
-        #     address.
-        #   - one for login that ratelimits login requests based on the account the
-        #     client is attempting to log into.
-        #   - one for login that ratelimits login requests based on the account the
-        #     client is attempting to log into, based on the amount of failed login
-        #     attempts for this account.
-        #   - one for ratelimiting redactions by room admins. If this is not explicitly
-        #     set then it uses the same ratelimiting as per rc_message. This is useful
-        #     to allow room admins to deal with abuse quickly.
-        #   - two for ratelimiting number of rooms a user can join, "local" for when
-        #     users are joining rooms the server is already in (this is cheap) vs
-        #     "remote" for when users are trying to join rooms not on the server (which
-        #     can be more expensive)
-        #   - one for ratelimiting how often a user or IP can attempt to validate a 3PID.
-        #   - two for ratelimiting how often invites can be sent in a room or to a
-        #     specific user.
-        #   - one for ratelimiting 3PID invites (i.e. invites sent to a third-party ID
-        #     such as an email address or a phone number) based on the account that's
-        #     sending the invite.
-        #
-        # The defaults are as shown below.
-        #
-        #rc_message:
-        #  per_second: 0.2
-        #  burst_count: 10
-        #
-        #rc_registration:
-        #  per_second: 0.17
-        #  burst_count: 3
-        #
-        #rc_registration_token_validity:
-        #  per_second: 0.1
-        #  burst_count: 5
-        #
-        #rc_login:
-        #  address:
-        #    per_second: 0.17
-        #    burst_count: 3
-        #  account:
-        #    per_second: 0.17
-        #    burst_count: 3
-        #  failed_attempts:
-        #    per_second: 0.17
-        #    burst_count: 3
-        #
-        #rc_admin_redaction:
-        #  per_second: 1
-        #  burst_count: 50
-        #
-        #rc_joins:
-        #  local:
-        #    per_second: 0.1
-        #    burst_count: 10
-        #  remote:
-        #    per_second: 0.01
-        #    burst_count: 10
-        #
-        #rc_3pid_validation:
-        #  per_second: 0.003
-        #  burst_count: 5
-        #
-        #rc_invites:
-        #  per_room:
-        #    per_second: 0.3
-        #    burst_count: 10
-        #  per_user:
-        #    per_second: 0.003
-        #    burst_count: 5
-        #
-        #rc_third_party_invite:
-        #  per_second: 0.2
-        #  burst_count: 10
-
-        # Ratelimiting settings for incoming federation
-        #
-        # The rc_federation configuration is made up of the following settings:
-        #   - window_size: window size in milliseconds
-        #   - sleep_limit: number of federation requests from a single server in
-        #     a window before the server will delay processing the request.
-        #   - sleep_delay: duration in milliseconds to delay processing events
-        #     from remote servers by if they go over the sleep limit.
-        #   - reject_limit: maximum number of concurrent federation requests
-        #     allowed from a single server
-        #   - concurrent: number of federation requests to concurrently process
-        #     from a single server
-        #
-        # The defaults are as shown below.
-        #
-        #rc_federation:
-        #  window_size: 1000
-        #  sleep_limit: 10
-        #  sleep_delay: 500
-        #  reject_limit: 50
-        #  concurrent: 3
-
-        # Target outgoing federation transaction frequency for sending read-receipts,
-        # per-room.
-        #
-        # If we end up trying to send out more read-receipts, they will get buffered up
-        # into fewer transactions.
-        #
-        #federation_rr_transactions_per_room_per_second: 50
-        """
diff --git a/synapse/config/redis.py b/synapse/config/redis.py
index ec7a735418..b42dd2e93a 100644
--- a/synapse/config/redis.py
+++ b/synapse/config/redis.py
@@ -34,24 +34,3 @@ class RedisConfig(Config):
         self.redis_host = redis_config.get("host", "localhost")
         self.redis_port = redis_config.get("port", 6379)
         self.redis_password = redis_config.get("password")
-
-    def generate_config_section(self, **kwargs: Any) -> str:
-        return """\
-        # Configuration for Redis when using workers. This *must* be enabled when
-        # using workers (unless using old style direct TCP configuration).
-        #
-        redis:
-          # Uncomment the below to enable Redis support.
-          #
-          #enabled: true
-
-          # Optional host and port to use to connect to redis. Defaults to
-          # localhost and 6379
-          #
-          #host: localhost
-          #port: 6379
-
-          # Optional password if configured on the Redis instance
-          #
-          #password: <secret_password>
-        """
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index d2d0425e62..df1d83dfaa 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -13,13 +13,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
-from typing import Any, Optional
+from typing import Any, Dict, Optional
 
 from synapse.api.constants import RoomCreationPreset
-from synapse.config._base import Config, ConfigError
+from synapse.config._base import Config, ConfigError, read_file
 from synapse.types import JsonDict, RoomAlias, UserID
 from synapse.util.stringutils import random_string_with_symbols, strtobool
 
+NO_EMAIL_DELEGATE_ERROR = """\
+Delegation of email verification to an identity server is no longer supported. To
+continue to allow users to add email addresses to their accounts, and use them for
+password resets, configure Synapse with an SMTP server via the `email` setting, and
+remove `account_threepid_delegates.email`.
+"""
+
+CONFLICTING_SHARED_SECRET_OPTS_ERROR = """\
+You have configured both `registration_shared_secret` and
+`registration_shared_secret_path`. These are mutually incompatible.
+"""
+
 
 class RegistrationConfig(Config):
     section = "registration"
@@ -46,12 +58,22 @@ class RegistrationConfig(Config):
         self.enable_registration_token_3pid_bypass = config.get(
             "enable_registration_token_3pid_bypass", False
         )
+
+        # read the shared secret, either inline or from an external file
         self.registration_shared_secret = config.get("registration_shared_secret")
+        registration_shared_secret_path = config.get("registration_shared_secret_path")
+        if registration_shared_secret_path:
+            if self.registration_shared_secret:
+                raise ConfigError(CONFLICTING_SHARED_SECRET_OPTS_ERROR)
+            self.registration_shared_secret = read_file(
+                registration_shared_secret_path, ("registration_shared_secret_path",)
+            ).strip()
 
         self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
 
         account_threepid_delegates = config.get("account_threepid_delegates") or {}
-        self.account_threepid_delegate_email = account_threepid_delegates.get("email")
+        if "email" in account_threepid_delegates:
+            raise ConfigError(NO_EMAIL_DELEGATE_ERROR)
         self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn")
         self.default_identity_server = config.get("default_identity_server")
         self.allow_guest_access = config.get("allow_guest_access", False)
@@ -206,284 +228,24 @@ class RegistrationConfig(Config):
             registration_shared_secret = 'registration_shared_secret: "%s"' % (
                 random_string_with_symbols(50),
             )
+            return registration_shared_secret
         else:
-            registration_shared_secret = "#registration_shared_secret: <PRIVATE STRING>"
-
-        return (
-            """\
-        ## Registration ##
-        #
-        # Registration can be rate-limited using the parameters in the "Ratelimiting"
-        # section of this file.
-
-        # Enable registration for new users. Defaults to 'false'. It is highly recommended that if you enable registration,
-        # you use either captcha, email, or token-based verification to verify that new users are not bots. In order to enable registration
-        # without any verification, you must also set `enable_registration_without_verification`, found below.
-        #
-        #enable_registration: false
-
-        # Enable registration without email or captcha verification. Note: this option is *not* recommended,
-        # as registration without verification is a known vector for spam and abuse. Defaults to false. Has no effect
-        # unless `enable_registration` is also enabled.
-        #
-        #enable_registration_without_verification: true
-
-        # Time that a user's session remains valid for, after they log in.
-        #
-        # Note that this is not currently compatible with guest logins.
-        #
-        # Note also that this is calculated at login time: changes are not applied
-        # retrospectively to users who have already logged in.
-        #
-        # By default, this is infinite.
-        #
-        #session_lifetime: 24h
-
-        # Time that an access token remains valid for, if the session is
-        # using refresh tokens.
-        # For more information about refresh tokens, please see the manual.
-        # Note that this only applies to clients which advertise support for
-        # refresh tokens.
-        #
-        # Note also that this is calculated at login time and refresh time:
-        # changes are not applied to existing sessions until they are refreshed.
-        #
-        # By default, this is 5 minutes.
-        #
-        #refreshable_access_token_lifetime: 5m
-
-        # Time that a refresh token remains valid for (provided that it is not
-        # exchanged for another one first).
-        # This option can be used to automatically log-out inactive sessions.
-        # Please see the manual for more information.
-        #
-        # Note also that this is calculated at login time and refresh time:
-        # changes are not applied to existing sessions until they are refreshed.
-        #
-        # By default, this is infinite.
-        #
-        #refresh_token_lifetime: 24h
-
-        # Time that an access token remains valid for, if the session is NOT
-        # using refresh tokens.
-        # Please note that not all clients support refresh tokens, so setting
-        # this to a short value may be inconvenient for some users who will
-        # then be logged out frequently.
-        #
-        # Note also that this is calculated at login time: changes are not applied
-        # retrospectively to existing sessions for users that have already logged in.
-        #
-        # By default, this is infinite.
-        #
-        #nonrefreshable_access_token_lifetime: 24h
-
-        # The user must provide all of the below types of 3PID when registering.
-        #
-        #registrations_require_3pid:
-        #  - email
-        #  - msisdn
-
-        # Explicitly disable asking for MSISDNs from the registration
-        # flow (overrides registrations_require_3pid if MSISDNs are set as required)
-        #
-        #disable_msisdn_registration: true
-
-        # Mandate that users are only allowed to associate certain formats of
-        # 3PIDs with accounts on this server.
-        #
-        #allowed_local_3pids:
-        #  - medium: email
-        #    pattern: '^[^@]+@matrix\\.org$'
-        #  - medium: email
-        #    pattern: '^[^@]+@vector\\.im$'
-        #  - medium: msisdn
-        #    pattern: '\\+44'
-
-        # Enable 3PIDs lookup requests to identity servers from this server.
-        #
-        #enable_3pid_lookup: true
-
-        # Require users to submit a token during registration.
-        # Tokens can be managed using the admin API:
-        # https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/registration_tokens.html
-        # Note that `enable_registration` must be set to `true`.
-        # Disabling this option will not delete any tokens previously generated.
-        # Defaults to false. Uncomment the following to require tokens:
-        #
-        #registration_requires_token: true
-
-        # Allow users to submit a token during registration to bypass any required 3pid
-        # steps configured in `registrations_require_3pid`.
-        # Defaults to false, requiring that registration tokens (if enabled) complete a 3pid flow.
-        #
-        #enable_registration_token_3pid_bypass: false
-
-        # If set, allows registration of standard or admin accounts by anyone who
-        # has the shared secret, even if registration is otherwise disabled.
-        #
-        %(registration_shared_secret)s
-
-        # Set the number of bcrypt rounds used to generate password hash.
-        # Larger numbers increase the work factor needed to generate the hash.
-        # The default number is 12 (which equates to 2^12 rounds).
-        # N.B. that increasing this will exponentially increase the time required
-        # to register or login - e.g. 24 => 2^24 rounds which will take >20 mins.
-        #
-        #bcrypt_rounds: 12
-
-        # Allows users to register as guests without a password/email/etc, and
-        # participate in rooms hosted on this server which have been made
-        # accessible to anonymous users.
-        #
-        #allow_guest_access: false
-
-        # The identity server which we suggest that clients should use when users log
-        # in on this server.
-        #
-        # (By default, no suggestion is made, so it is left up to the client.
-        # This setting is ignored unless public_baseurl is also explicitly set.)
-        #
-        #default_identity_server: https://matrix.org
-
-        # Handle threepid (email/phone etc) registration and password resets through a set of
-        # *trusted* identity servers. Note that this allows the configured identity server to
-        # reset passwords for accounts!
-        #
-        # Be aware that if `email` is not set, and SMTP options have not been
-        # configured in the email config block, registration and user password resets via
-        # email will be globally disabled.
-        #
-        # Additionally, if `msisdn` is not set, registration and password resets via msisdn
-        # will be disabled regardless, and users will not be able to associate an msisdn
-        # identifier to their account. This is due to Synapse currently not supporting
-        # any method of sending SMS messages on its own.
-        #
-        # To enable using an identity server for operations regarding a particular third-party
-        # identifier type, set the value to the URL of that identity server as shown in the
-        # examples below.
-        #
-        # Servers handling the these requests must answer the `/requestToken` endpoints defined
-        # by the Matrix Identity Service API specification:
-        # https://matrix.org/docs/spec/identity_service/latest
-        #
-        account_threepid_delegates:
-            #email: https://example.com     # Delegate email sending to example.com
-            #msisdn: http://localhost:8090  # Delegate SMS sending to this local process
-
-        # Whether users are allowed to change their displayname after it has
-        # been initially set. Useful when provisioning users based on the
-        # contents of a third-party directory.
-        #
-        # Does not apply to server administrators. Defaults to 'true'
-        #
-        #enable_set_displayname: false
-
-        # Whether users are allowed to change their avatar after it has been
-        # initially set. Useful when provisioning users based on the contents
-        # of a third-party directory.
-        #
-        # Does not apply to server administrators. Defaults to 'true'
-        #
-        #enable_set_avatar_url: false
-
-        # Whether users can change the 3PIDs associated with their accounts
-        # (email address and msisdn).
-        #
-        # Defaults to 'true'
-        #
-        #enable_3pid_changes: false
-
-        # Users who register on this homeserver will automatically be joined
-        # to these rooms.
-        #
-        # By default, any room aliases included in this list will be created
-        # as a publicly joinable room when the first user registers for the
-        # homeserver. This behaviour can be customised with the settings below.
-        # If the room already exists, make certain it is a publicly joinable
-        # room. The join rule of the room must be set to 'public'.
-        #
-        #auto_join_rooms:
-        #  - "#example:example.com"
-
-        # Where auto_join_rooms are specified, setting this flag ensures that the
-        # the rooms exist by creating them when the first user on the
-        # homeserver registers.
-        #
-        # By default the auto-created rooms are publicly joinable from any federated
-        # server. Use the autocreate_auto_join_rooms_federated and
-        # autocreate_auto_join_room_preset settings below to customise this behaviour.
-        #
-        # Setting to false means that if the rooms are not manually created,
-        # users cannot be auto-joined since they do not exist.
-        #
-        # Defaults to true. Uncomment the following line to disable automatically
-        # creating auto-join rooms.
-        #
-        #autocreate_auto_join_rooms: false
-
-        # Whether the auto_join_rooms that are auto-created are available via
-        # federation. Only has an effect if autocreate_auto_join_rooms is true.
-        #
-        # Note that whether a room is federated cannot be modified after
-        # creation.
-        #
-        # Defaults to true: the room will be joinable from other servers.
-        # Uncomment the following to prevent users from other homeservers from
-        # joining these rooms.
-        #
-        #autocreate_auto_join_rooms_federated: false
-
-        # The room preset to use when auto-creating one of auto_join_rooms. Only has an
-        # effect if autocreate_auto_join_rooms is true.
-        #
-        # This can be one of "public_chat", "private_chat", or "trusted_private_chat".
-        # If a value of "private_chat" or "trusted_private_chat" is used then
-        # auto_join_mxid_localpart must also be configured.
-        #
-        # Defaults to "public_chat", meaning that the room is joinable by anyone, including
-        # federated servers if autocreate_auto_join_rooms_federated is true (the default).
-        # Uncomment the following to require an invitation to join these rooms.
-        #
-        #autocreate_auto_join_room_preset: private_chat
-
-        # The local part of the user id which is used to create auto_join_rooms if
-        # autocreate_auto_join_rooms is true. If this is not provided then the
-        # initial user account that registers will be used to create the rooms.
-        #
-        # The user id is also used to invite new users to any auto-join rooms which
-        # are set to invite-only.
-        #
-        # It *must* be configured if autocreate_auto_join_room_preset is set to
-        # "private_chat" or "trusted_private_chat".
-        #
-        # Note that this must be specified in order for new users to be correctly
-        # invited to any auto-join rooms which have been set to invite-only (either
-        # at the time of creation or subsequently).
-        #
-        # Note that, if the room already exists, this user must be joined and
-        # have the appropriate permissions to invite new members.
-        #
-        #auto_join_mxid_localpart: system
-
-        # When auto_join_rooms is specified, setting this flag to false prevents
-        # guest accounts from being automatically joined to the rooms.
-        #
-        # Defaults to true.
-        #
-        #auto_join_rooms_for_guests: false
-
-        # Whether to inhibit errors raised when registering a new account if the user ID
-        # already exists. If turned on, that requests to /register/available will always
-        # show a user ID as available, and Synapse won't raise an error when starting
-        # a registration with a user ID that already exists. However, Synapse will still
-        # raise an error if the registration completes and the username conflicts.
-        #
-        # Defaults to false.
-        #
-        #inhibit_user_in_use_error: true
-        """
-            % locals()
-        )
+            return ""
+
+    def generate_files(self, config: Dict[str, Any], config_dir_path: str) -> None:
+        # if 'registration_shared_secret_path' is specified, and the target file
+        # does not exist, generate it.
+        registration_shared_secret_path = config.get("registration_shared_secret_path")
+        if registration_shared_secret_path and not self.path_exists(
+            registration_shared_secret_path
+        ):
+            print(
+                "Generating registration shared secret file "
+                + registration_shared_secret_path
+            )
+            secret = random_string_with_symbols(50)
+            with open(registration_shared_secret_path, "w") as f:
+                f.write(f"{secret}\n")
 
     @staticmethod
     def add_arguments(parser: argparse.ArgumentParser) -> None:
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index f9c55143c3..1033496bb4 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -19,9 +19,9 @@ from urllib.request import getproxies_environment  # type: ignore
 
 import attr
 
-from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST, generate_ip_set
+from synapse.config.server import generate_ip_set
 from synapse.types import JsonDict
-from synapse.util.check_dependencies import DependencyException, check_requirements
+from synapse.util.check_dependencies import check_requirements
 from synapse.util.module_loader import load_module
 
 from ._base import Config, ConfigError
@@ -42,6 +42,18 @@ THUMBNAIL_SIZE_YAML = """\
         #    method: %(method)s
 """
 
+# A map from the given media type to the type of thumbnail we should generate
+# for it.
+THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP = {
+    "image/jpeg": "jpeg",
+    "image/jpg": "jpeg",
+    "image/webp": "jpeg",
+    # Thumbnails can only be jpeg or png. We choose png thumbnails for gif
+    # because it can have transparency.
+    "image/gif": "png",
+    "image/png": "png",
+}
+
 HTTP_PROXY_SET_WARNING = """\
 The Synapse config url_preview_ip_range_blacklist will be ignored as an HTTP(s) proxy is configured."""
 
@@ -79,13 +91,22 @@ def parse_thumbnail_requirements(
         width = size["width"]
         height = size["height"]
         method = size["method"]
-        jpeg_thumbnail = ThumbnailRequirement(width, height, method, "image/jpeg")
-        png_thumbnail = ThumbnailRequirement(width, height, method, "image/png")
-        requirements.setdefault("image/jpeg", []).append(jpeg_thumbnail)
-        requirements.setdefault("image/jpg", []).append(jpeg_thumbnail)
-        requirements.setdefault("image/webp", []).append(jpeg_thumbnail)
-        requirements.setdefault("image/gif", []).append(png_thumbnail)
-        requirements.setdefault("image/png", []).append(png_thumbnail)
+
+        for format, thumbnail_format in THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP.items():
+            requirement = requirements.setdefault(format, [])
+            if thumbnail_format == "jpeg":
+                requirement.append(
+                    ThumbnailRequirement(width, height, method, "image/jpeg")
+                )
+            elif thumbnail_format == "png":
+                requirement.append(
+                    ThumbnailRequirement(width, height, method, "image/png")
+                )
+            else:
+                raise Exception(
+                    "Unknown thumbnail mapping from %s to %s. This is a Synapse problem, please report!"
+                    % (format, thumbnail_format)
+                )
     return {
         media_type: tuple(thumbnails) for media_type, thumbnails in requirements.items()
     }
@@ -184,13 +205,7 @@ class ContentRepositoryConfig(Config):
         )
         self.url_preview_enabled = config.get("url_preview_enabled", False)
         if self.url_preview_enabled:
-            try:
-                check_requirements("url_preview")
-
-            except DependencyException as e:
-                raise ConfigError(
-                    e.message  # noqa: B306, DependencyException.message is a property
-                )
+            check_requirements("url_preview")
 
             proxy_env = getproxies_environment()
             if "url_preview_ip_range_blacklist" not in config:
@@ -242,166 +257,4 @@ class ContentRepositoryConfig(Config):
     def generate_config_section(self, data_dir_path: str, **kwargs: Any) -> str:
         assert data_dir_path is not None
         media_store = os.path.join(data_dir_path, "media_store")
-
-        formatted_thumbnail_sizes = "".join(
-            THUMBNAIL_SIZE_YAML % s for s in DEFAULT_THUMBNAIL_SIZES
-        )
-        # strip final NL
-        formatted_thumbnail_sizes = formatted_thumbnail_sizes[:-1]
-
-        ip_range_blacklist = "\n".join(
-            "        #  - '%s'" % ip for ip in DEFAULT_IP_RANGE_BLACKLIST
-        )
-
-        return (
-            r"""
-        ## Media Store ##
-
-        # Enable the media store service in the Synapse master. Uncomment the
-        # following if you are using a separate media store worker.
-        #
-        #enable_media_repo: false
-
-        # Directory where uploaded images and attachments are stored.
-        #
-        media_store_path: "%(media_store)s"
-
-        # Media storage providers allow media to be stored in different
-        # locations.
-        #
-        #media_storage_providers:
-        #  - module: file_system
-        #    # Whether to store newly uploaded local files
-        #    store_local: false
-        #    # Whether to store newly downloaded remote files
-        #    store_remote: false
-        #    # Whether to wait for successful storage for local uploads
-        #    store_synchronous: false
-        #    config:
-        #       directory: /mnt/some/other/directory
-
-        # The largest allowed upload size in bytes
-        #
-        # If you are using a reverse proxy you may also need to set this value in
-        # your reverse proxy's config. Notably Nginx has a small max body size by default.
-        # See https://matrix-org.github.io/synapse/latest/reverse_proxy.html.
-        #
-        #max_upload_size: 50M
-
-        # Maximum number of pixels that will be thumbnailed
-        #
-        #max_image_pixels: 32M
-
-        # Whether to generate new thumbnails on the fly to precisely match
-        # the resolution requested by the client. If true then whenever
-        # a new resolution is requested by the client the server will
-        # generate a new thumbnail. If false the server will pick a thumbnail
-        # from a precalculated list.
-        #
-        #dynamic_thumbnails: false
-
-        # List of thumbnails to precalculate when an image is uploaded.
-        #
-        #thumbnail_sizes:
-%(formatted_thumbnail_sizes)s
-
-        # Is the preview URL API enabled?
-        #
-        # 'false' by default: uncomment the following to enable it (and specify a
-        # url_preview_ip_range_blacklist blacklist).
-        #
-        #url_preview_enabled: true
-
-        # List of IP address CIDR ranges that the URL preview spider is denied
-        # from accessing.  There are no defaults: you must explicitly
-        # specify a list for URL previewing to work.  You should specify any
-        # internal services in your network that you do not want synapse to try
-        # to connect to, otherwise anyone in any Matrix room could cause your
-        # synapse to issue arbitrary GET requests to your internal services,
-        # causing serious security issues.
-        #
-        # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
-        # listed here, since they correspond to unroutable addresses.)
-        #
-        # This must be specified if url_preview_enabled is set. It is recommended that
-        # you uncomment the following list as a starting point.
-        #
-        # Note: The value is ignored when an HTTP proxy is in use
-        #
-        #url_preview_ip_range_blacklist:
-%(ip_range_blacklist)s
-
-        # List of IP address CIDR ranges that the URL preview spider is allowed
-        # to access even if they are specified in url_preview_ip_range_blacklist.
-        # This is useful for specifying exceptions to wide-ranging blacklisted
-        # target IP ranges - e.g. for enabling URL previews for a specific private
-        # website only visible in your network.
-        #
-        #url_preview_ip_range_whitelist:
-        #   - '192.168.1.1'
-
-        # Optional list of URL matches that the URL preview spider is
-        # denied from accessing.  You should use url_preview_ip_range_blacklist
-        # in preference to this, otherwise someone could define a public DNS
-        # entry that points to a private IP address and circumvent the blacklist.
-        # This is more useful if you know there is an entire shape of URL that
-        # you know that will never want synapse to try to spider.
-        #
-        # Each list entry is a dictionary of url component attributes as returned
-        # by urlparse.urlsplit as applied to the absolute form of the URL.  See
-        # https://docs.python.org/2/library/urlparse.html#urlparse.urlsplit
-        # The values of the dictionary are treated as an filename match pattern
-        # applied to that component of URLs, unless they start with a ^ in which
-        # case they are treated as a regular expression match.  If all the
-        # specified component matches for a given list item succeed, the URL is
-        # blacklisted.
-        #
-        #url_preview_url_blacklist:
-        #  # blacklist any URL with a username in its URI
-        #  - username: '*'
-        #
-        #  # blacklist all *.google.com URLs
-        #  - netloc: 'google.com'
-        #  - netloc: '*.google.com'
-        #
-        #  # blacklist all plain HTTP URLs
-        #  - scheme: 'http'
-        #
-        #  # blacklist http(s)://www.acme.com/foo
-        #  - netloc: 'www.acme.com'
-        #    path: '/foo'
-        #
-        #  # blacklist any URL with a literal IPv4 address
-        #  - netloc: '^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$'
-
-        # The largest allowed URL preview spidering size in bytes
-        #
-        #max_spider_size: 10M
-
-        # A list of values for the Accept-Language HTTP header used when
-        # downloading webpages during URL preview generation. This allows
-        # Synapse to specify the preferred languages that URL previews should
-        # be in when communicating with remote servers.
-        #
-        # Each value is a IETF language tag; a 2-3 letter identifier for a
-        # language, optionally followed by subtags separated by '-', specifying
-        # a country or region variant.
-        #
-        # Multiple values can be provided, and a weight can be added to each by
-        # using quality value syntax (;q=). '*' translates to any language.
-        #
-        # Defaults to "en".
-        #
-        # Example:
-        #
-        # url_preview_accept_language:
-        #   - en-UK
-        #   - en-US;q=0.9
-        #   - fr;q=0.8
-        #   - *;q=0.7
-        #
-        url_preview_accept_language:
-        #   - en
-        """
-            % locals()
-        )
+        return f"media_store_path: {media_store}"
diff --git a/synapse/config/retention.py b/synapse/config/retention.py
index 03b723b84b..033051a9c2 100644
--- a/synapse/config/retention.py
+++ b/synapse/config/retention.py
@@ -153,75 +153,3 @@ class RetentionConfig(Config):
             self.retention_purge_jobs = [
                 RetentionPurgeJob(self.parse_duration("1d"), None, None)
             ]
-
-    def generate_config_section(self, **kwargs: Any) -> str:
-        return """\
-        # Message retention policy at the server level.
-        #
-        # Room admins and mods can define a retention period for their rooms using the
-        # 'm.room.retention' state event, and server admins can cap this period by setting
-        # the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options.
-        #
-        # If this feature is enabled, Synapse will regularly look for and purge events
-        # which are older than the room's maximum retention period. Synapse will also
-        # filter events received over federation so that events that should have been
-        # purged are ignored and not stored again.
-        #
-        retention:
-          # The message retention policies feature is disabled by default. Uncomment the
-          # following line to enable it.
-          #
-          #enabled: true
-
-          # Default retention policy. If set, Synapse will apply it to rooms that lack the
-          # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't
-          # matter much because Synapse doesn't take it into account yet.
-          #
-          #default_policy:
-          #  min_lifetime: 1d
-          #  max_lifetime: 1y
-
-          # Retention policy limits. If set, and the state of a room contains a
-          # 'm.room.retention' event in its state which contains a 'min_lifetime' or a
-          # 'max_lifetime' that's out of these bounds, Synapse will cap the room's policy
-          # to these limits when running purge jobs.
-          #
-          #allowed_lifetime_min: 1d
-          #allowed_lifetime_max: 1y
-
-          # Server admins can define the settings of the background jobs purging the
-          # events which lifetime has expired under the 'purge_jobs' section.
-          #
-          # If no configuration is provided, a single job will be set up to delete expired
-          # events in every room daily.
-          #
-          # Each job's configuration defines which range of message lifetimes the job
-          # takes care of. For example, if 'shortest_max_lifetime' is '2d' and
-          # 'longest_max_lifetime' is '3d', the job will handle purging expired events in
-          # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and
-          # lower than or equal to 3 days. Both the minimum and the maximum value of a
-          # range are optional, e.g. a job with no 'shortest_max_lifetime' and a
-          # 'longest_max_lifetime' of '3d' will handle every room with a retention policy
-          # which 'max_lifetime' is lower than or equal to three days.
-          #
-          # The rationale for this per-job configuration is that some rooms might have a
-          # retention policy with a low 'max_lifetime', where history needs to be purged
-          # of outdated messages on a more frequent basis than for the rest of the rooms
-          # (e.g. every 12h), but not want that purge to be performed by a job that's
-          # iterating over every room it knows, which could be heavy on the server.
-          #
-          # If any purge job is configured, it is strongly recommended to have at least
-          # a single job with neither 'shortest_max_lifetime' nor 'longest_max_lifetime'
-          # set, or one job without 'shortest_max_lifetime' and one job without
-          # 'longest_max_lifetime' set. Otherwise some rooms might be ignored, even if
-          # 'allowed_lifetime_min' and 'allowed_lifetime_max' are set, because capping a
-          # room's policy to these values is done after the policies are retrieved from
-          # Synapse's database (which is done using the range specified in a purge job's
-          # configuration).
-          #
-          #purge_jobs:
-          #  - longest_max_lifetime: 3d
-          #    interval: 12h
-          #  - shortest_max_lifetime: 3d
-          #    interval: 1d
-        """
diff --git a/synapse/config/room.py b/synapse/config/room.py
index 462d85ac1d..4a7ac00540 100644
--- a/synapse/config/room.py
+++ b/synapse/config/room.py
@@ -75,59 +75,3 @@ class RoomConfig(Config):
                         % preset
                     )
                 # We validate the actual overrides when we try to apply them.
-
-    def generate_config_section(self, **kwargs: Any) -> str:
-        return """\
-        ## Rooms ##
-
-        # Controls whether locally-created rooms should be end-to-end encrypted by
-        # default.
-        #
-        # Possible options are "all", "invite", and "off". They are defined as:
-        #
-        # * "all": any locally-created room
-        # * "invite": any room created with the "private_chat" or "trusted_private_chat"
-        #             room creation presets
-        # * "off": this option will take no effect
-        #
-        # The default value is "off".
-        #
-        # Note that this option will only affect rooms created after it is set. It
-        # will also not affect rooms created by other servers.
-        #
-        #encryption_enabled_by_default_for_room_type: invite
-
-        # Override the default power levels for rooms created on this server, per
-        # room creation preset.
-        #
-        # The appropriate dictionary for the room preset will be applied on top
-        # of the existing power levels content.
-        #
-        # Useful if you know that your users need special permissions in rooms
-        # that they create (e.g. to send particular types of state events without
-        # needing an elevated power level).  This takes the same shape as the
-        # `power_level_content_override` parameter in the /createRoom API, but
-        # is applied before that parameter.
-        #
-        # Valid keys are some or all of `private_chat`, `trusted_private_chat`
-        # and `public_chat`. Inside each of those should be any of the
-        # properties allowed in `power_level_content_override` in the
-        # /createRoom API. If any property is missing, its default value will
-        # continue to be used. If any property is present, it will overwrite
-        # the existing default completely (so if the `events` property exists,
-        # the default event power levels will be ignored).
-        #
-        #default_power_level_content_override:
-        #    private_chat:
-        #        "events":
-        #            "com.example.myeventtype" : 0
-        #            "m.room.avatar": 50
-        #            "m.room.canonical_alias": 50
-        #            "m.room.encryption": 100
-        #            "m.room.history_visibility": 100
-        #            "m.room.name": 50
-        #            "m.room.power_levels": 100
-        #            "m.room.server_acl": 100
-        #            "m.room.tombstone": 100
-        #        "events_default": 1
-        """
diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py
index 717ba70e1c..3ed236217f 100644
--- a/synapse/config/room_directory.py
+++ b/synapse/config/room_directory.py
@@ -52,72 +52,6 @@ class RoomDirectoryConfig(Config):
                 _RoomDirectoryRule("room_list_publication_rules", {"action": "allow"})
             ]
 
-    def generate_config_section(self, **kwargs: Any) -> str:
-        return """
-        # Uncomment to disable searching the public room list. When disabled
-        # blocks searching local and remote room lists for local and remote
-        # users by always returning an empty list for all queries.
-        #
-        #enable_room_list_search: false
-
-        # The `alias_creation` option controls who's allowed to create aliases
-        # on this server.
-        #
-        # The format of this option is a list of rules that contain globs that
-        # match against user_id, room_id and the new alias (fully qualified with
-        # server name). The action in the first rule that matches is taken,
-        # which can currently either be "allow" or "deny".
-        #
-        # Missing user_id/room_id/alias fields default to "*".
-        #
-        # If no rules match the request is denied. An empty list means no one
-        # can create aliases.
-        #
-        # Options for the rules include:
-        #
-        #   user_id: Matches against the creator of the alias
-        #   alias: Matches against the alias being created
-        #   room_id: Matches against the room ID the alias is being pointed at
-        #   action: Whether to "allow" or "deny" the request if the rule matches
-        #
-        # The default is:
-        #
-        #alias_creation_rules:
-        #  - user_id: "*"
-        #    alias: "*"
-        #    room_id: "*"
-        #    action: allow
-
-        # The `room_list_publication_rules` option controls who can publish and
-        # which rooms can be published in the public room list.
-        #
-        # The format of this option is the same as that for
-        # `alias_creation_rules`.
-        #
-        # If the room has one or more aliases associated with it, only one of
-        # the aliases needs to match the alias rule. If there are no aliases
-        # then only rules with `alias: *` match.
-        #
-        # If no rules match the request is denied. An empty list means no one
-        # can publish rooms.
-        #
-        # Options for the rules include:
-        #
-        #   user_id: Matches against the creator of the alias
-        #   room_id: Matches against the room ID being published
-        #   alias: Matches against any current local or canonical aliases
-        #            associated with the room
-        #   action: Whether to "allow" or "deny" the request if the rule matches
-        #
-        # The default is:
-        #
-        #room_list_publication_rules:
-        #  - user_id: "*"
-        #    alias: "*"
-        #    room_id: "*"
-        #    action: allow
-        """
-
     def is_alias_creation_allowed(self, user_id: str, room_id: str, alias: str) -> bool:
         """Checks if the given user is allowed to create the given alias
 
diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py
index 19b2f1b215..49ca663dde 100644
--- a/synapse/config/saml2.py
+++ b/synapse/config/saml2.py
@@ -18,7 +18,7 @@ from typing import Any, List, Set
 
 from synapse.config.sso import SsoAttributeRequirement
 from synapse.types import JsonDict
-from synapse.util.check_dependencies import DependencyException, check_requirements
+from synapse.util.check_dependencies import check_requirements
 from synapse.util.module_loader import load_module, load_python_module
 
 from ._base import Config, ConfigError
@@ -76,12 +76,7 @@ class SAML2Config(Config):
         if not saml2_config.get("sp_config") and not saml2_config.get("config_path"):
             return
 
-        try:
-            check_requirements("saml2")
-        except DependencyException as e:
-            raise ConfigError(
-                e.message  # noqa: B306, DependencyException.message is a property
-            )
+        check_requirements("saml2")
 
         self.saml2_enabled = True
 
@@ -223,189 +218,6 @@ class SAML2Config(Config):
             },
         }
 
-    def generate_config_section(self, config_dir_path: str, **kwargs: Any) -> str:
-        return """\
-        ## Single sign-on integration ##
-
-        # The following settings can be used to make Synapse use a single sign-on
-        # provider for authentication, instead of its internal password database.
-        #
-        # You will probably also want to set the following options to `false` to
-        # disable the regular login/registration flows:
-        #   * enable_registration
-        #   * password_config.enabled
-        #
-        # You will also want to investigate the settings under the "sso" configuration
-        # section below.
-
-        # Enable SAML2 for registration and login. Uses pysaml2.
-        #
-        # At least one of `sp_config` or `config_path` must be set in this section to
-        # enable SAML login.
-        #
-        # Once SAML support is enabled, a metadata file will be exposed at
-        # https://<server>:<port>/_synapse/client/saml2/metadata.xml, which you may be able to
-        # use to configure your SAML IdP with. Alternatively, you can manually configure
-        # the IdP to use an ACS location of
-        # https://<server>:<port>/_synapse/client/saml2/authn_response.
-        #
-        saml2_config:
-          # `sp_config` is the configuration for the pysaml2 Service Provider.
-          # See pysaml2 docs for format of config.
-          #
-          # Default values will be used for the 'entityid' and 'service' settings,
-          # so it is not normally necessary to specify them unless you need to
-          # override them.
-          #
-          sp_config:
-            # Point this to the IdP's metadata. You must provide either a local
-            # file via the `local` attribute or (preferably) a URL via the
-            # `remote` attribute.
-            #
-            #metadata:
-            #  local: ["saml2/idp.xml"]
-            #  remote:
-            #    - url: https://our_idp/metadata.xml
-
-            # Allowed clock difference in seconds between the homeserver and IdP.
-            #
-            # Uncomment the below to increase the accepted time difference from 0 to 3 seconds.
-            #
-            #accepted_time_diff: 3
-
-            # By default, the user has to go to our login page first. If you'd like
-            # to allow IdP-initiated login, set 'allow_unsolicited: true' in a
-            # 'service.sp' section:
-            #
-            #service:
-            #  sp:
-            #    allow_unsolicited: true
-
-            # The examples below are just used to generate our metadata xml, and you
-            # may well not need them, depending on your setup. Alternatively you
-            # may need a whole lot more detail - see the pysaml2 docs!
-
-            #description: ["My awesome SP", "en"]
-            #name: ["Test SP", "en"]
-
-            #ui_info:
-            #  display_name:
-            #    - lang: en
-            #      text: "Display Name is the descriptive name of your service."
-            #  description:
-            #    - lang: en
-            #      text: "Description should be a short paragraph explaining the purpose of the service."
-            #  information_url:
-            #    - lang: en
-            #      text: "https://example.com/terms-of-service"
-            #  privacy_statement_url:
-            #    - lang: en
-            #      text: "https://example.com/privacy-policy"
-            #  keywords:
-            #    - lang: en
-            #      text: ["Matrix", "Element"]
-            #  logo:
-            #    - lang: en
-            #      text: "https://example.com/logo.svg"
-            #      width: "200"
-            #      height: "80"
-
-            #organization:
-            #  name: Example com
-            #  display_name:
-            #    - ["Example co", "en"]
-            #  url: "http://example.com"
-
-            #contact_person:
-            #  - given_name: Bob
-            #    sur_name: "the Sysadmin"
-            #    email_address": ["admin@example.com"]
-            #    contact_type": technical
-
-          # Instead of putting the config inline as above, you can specify a
-          # separate pysaml2 configuration file:
-          #
-          #config_path: "%(config_dir_path)s/sp_conf.py"
-
-          # The lifetime of a SAML session. This defines how long a user has to
-          # complete the authentication process, if allow_unsolicited is unset.
-          # The default is 15 minutes.
-          #
-          #saml_session_lifetime: 5m
-
-          # An external module can be provided here as a custom solution to
-          # mapping attributes returned from a saml provider onto a matrix user.
-          #
-          user_mapping_provider:
-            # The custom module's class. Uncomment to use a custom module.
-            #
-            #module: mapping_provider.SamlMappingProvider
-
-            # Custom configuration values for the module. Below options are
-            # intended for the built-in provider, they should be changed if
-            # using a custom module. This section will be passed as a Python
-            # dictionary to the module's `parse_config` method.
-            #
-            config:
-              # The SAML attribute (after mapping via the attribute maps) to use
-              # to derive the Matrix ID from. 'uid' by default.
-              #
-              # Note: This used to be configured by the
-              # saml2_config.mxid_source_attribute option. If that is still
-              # defined, its value will be used instead.
-              #
-              #mxid_source_attribute: displayName
-
-              # The mapping system to use for mapping the saml attribute onto a
-              # matrix ID.
-              #
-              # Options include:
-              #  * 'hexencode' (which maps unpermitted characters to '=xx')
-              #  * 'dotreplace' (which replaces unpermitted characters with
-              #     '.').
-              # The default is 'hexencode'.
-              #
-              # Note: This used to be configured by the
-              # saml2_config.mxid_mapping option. If that is still defined, its
-              # value will be used instead.
-              #
-              #mxid_mapping: dotreplace
-
-          # In previous versions of synapse, the mapping from SAML attribute to
-          # MXID was always calculated dynamically rather than stored in a
-          # table. For backwards- compatibility, we will look for user_ids
-          # matching such a pattern before creating a new account.
-          #
-          # This setting controls the SAML attribute which will be used for this
-          # backwards-compatibility lookup. Typically it should be 'uid', but if
-          # the attribute maps are changed, it may be necessary to change it.
-          #
-          # The default is 'uid'.
-          #
-          #grandfathered_mxid_source_attribute: upn
-
-          # It is possible to configure Synapse to only allow logins if SAML attributes
-          # match particular values. The requirements can be listed under
-          # `attribute_requirements` as shown below. All of the listed attributes must
-          # match for the login to be permitted.
-          #
-          #attribute_requirements:
-          #  - attribute: userGroup
-          #    value: "staff"
-          #  - attribute: department
-          #    value: "sales"
-
-          # If the metadata XML contains multiple IdP entities then the `idp_entityid`
-          # option must be set to the entity to redirect users to.
-          #
-          # Most deployments only have a single IdP entity and so should omit this
-          # option.
-          #
-          #idp_entityid: 'https://our_idp/entityid'
-        """ % {
-            "config_dir_path": config_dir_path
-        }
-
 
 ATTRIBUTE_REQUIREMENTS_SCHEMA = {
     "type": "array",
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 657322cb1f..085fe22c51 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -16,7 +16,6 @@ import argparse
 import itertools
 import logging
 import os.path
-import re
 import urllib.parse
 from textwrap import indent
 from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
@@ -302,6 +301,26 @@ class ServerConfig(Config):
                 "public_baseurl cannot contain query parameters or a #-fragment"
             )
 
+        self.extra_well_known_client_content = config.get(
+            "extra_well_known_client_content", {}
+        )
+
+        if not isinstance(self.extra_well_known_client_content, dict):
+            raise ConfigError(
+                "extra_well_known_content must be a dictionary of key-value pairs"
+            )
+
+        if "m.homeserver" in self.extra_well_known_client_content:
+            raise ConfigError(
+                "m.homeserver is not supported in extra_well_known_content, "
+                "use public_baseurl in base config instead."
+            )
+        if "m.identity_server" in self.extra_well_known_client_content:
+            raise ConfigError(
+                "m.identity_server is not supported in extra_well_known_content, "
+                "use default_identity_server in base config instead."
+            )
+
         # Whether to enable user presence.
         presence_config = config.get("presence") or {}
         self.use_presence = presence_config.get("enabled")
@@ -702,9 +721,6 @@ class ServerConfig(Config):
         listeners: Optional[List[dict]],
         **kwargs: Any,
     ) -> str:
-        ip_range_blacklist = "\n".join(
-            "        #  - '%s'" % ip for ip in DEFAULT_IP_RANGE_BLACKLIST
-        )
 
         _, bind_port = parse_and_validate_server_name(server_name)
         if bind_port is not None:
@@ -715,9 +731,6 @@ class ServerConfig(Config):
 
         pid_file = os.path.join(data_dir_path, "homeserver.pid")
 
-        # Bring DEFAULT_ROOM_VERSION into the local-scope for use in the
-        # default config string
-        default_room_version = DEFAULT_ROOM_VERSION
         secure_listeners = []
         unsecure_listeners = []
         private_addresses = ["::1", "127.0.0.1"]
@@ -765,501 +778,18 @@ class ServerConfig(Config):
                 compress: false"""
 
             if listeners:
-                # comment out this block
-                unsecure_http_bindings = "#" + re.sub(
-                    "\n {10}",
-                    lambda match: match.group(0) + "#",
-                    unsecure_http_bindings,
-                )
+                unsecure_http_bindings = ""
 
         if not secure_listeners:
-            secure_http_bindings = (
-                """#- port: %(bind_port)s
-          #  type: http
-          #  tls: true
-          #  resources:
-          #    - names: [client, federation]"""
-                % locals()
-            )
+            secure_http_bindings = ""
 
         return (
             """\
-        ## Server ##
-
-        # The public-facing domain of the server
-        #
-        # The server_name name will appear at the end of usernames and room addresses
-        # created on this server. For example if the server_name was example.com,
-        # usernames on this server would be in the format @user:example.com
-        #
-        # In most cases you should avoid using a matrix specific subdomain such as
-        # matrix.example.com or synapse.example.com as the server_name for the same
-        # reasons you wouldn't use user@email.example.com as your email address.
-        # See https://matrix-org.github.io/synapse/latest/delegate.html
-        # for information on how to host Synapse on a subdomain while preserving
-        # a clean server_name.
-        #
-        # The server_name cannot be changed later so it is important to
-        # configure this correctly before you start Synapse. It should be all
-        # lowercase and may contain an explicit port.
-        # Examples: matrix.org, localhost:8080
-        #
         server_name: "%(server_name)s"
-
-        # When running as a daemon, the file to store the pid in
-        #
         pid_file: %(pid_file)s
-
-        # The absolute URL to the web client which / will redirect to.
-        #
-        #web_client_location: https://riot.example.com/
-
-        # The public-facing base URL that clients use to access this Homeserver (not
-        # including _matrix/...). This is the same URL a user might enter into the
-        # 'Custom Homeserver URL' field on their client. If you use Synapse with a
-        # reverse proxy, this should be the URL to reach Synapse via the proxy.
-        # Otherwise, it should be the URL to reach Synapse's client HTTP listener (see
-        # 'listeners' below).
-        #
-        # Defaults to 'https://<server_name>/'.
-        #
-        #public_baseurl: https://example.com/
-
-        # Uncomment the following to tell other servers to send federation traffic on
-        # port 443.
-        #
-        # By default, other servers will try to reach our server on port 8448, which can
-        # be inconvenient in some environments.
-        #
-        # Provided 'https://<server_name>/' on port 443 is routed to Synapse, this
-        # option configures Synapse to serve a file at
-        # 'https://<server_name>/.well-known/matrix/server'. This will tell other
-        # servers to send traffic to port 443 instead.
-        #
-        # See https://matrix-org.github.io/synapse/latest/delegate.html for more
-        # information.
-        #
-        # Defaults to 'false'.
-        #
-        #serve_server_wellknown: true
-
-        # Set the soft limit on the number of file descriptors synapse can use
-        # Zero is used to indicate synapse should set the soft limit to the
-        # hard limit.
-        #
-        #soft_file_limit: 0
-
-        # Presence tracking allows users to see the state (e.g online/offline)
-        # of other local and remote users.
-        #
-        presence:
-          # Uncomment to disable presence tracking on this homeserver. This option
-          # replaces the previous top-level 'use_presence' option.
-          #
-          #enabled: false
-
-        # Whether to require authentication to retrieve profile data (avatars,
-        # display names) of other users through the client API. Defaults to
-        # 'false'. Note that profile data is also available via the federation
-        # API, unless allow_profile_lookup_over_federation is set to false.
-        #
-        #require_auth_for_profile_requests: true
-
-        # Uncomment to require a user to share a room with another user in order
-        # to retrieve their profile information. Only checked on Client-Server
-        # requests. Profile requests from other servers should be checked by the
-        # requesting server. Defaults to 'false'.
-        #
-        #limit_profile_requests_to_users_who_share_rooms: true
-
-        # Uncomment to prevent a user's profile data from being retrieved and
-        # displayed in a room until they have joined it. By default, a user's
-        # profile data is included in an invite event, regardless of the values
-        # of the above two settings, and whether or not the users share a server.
-        # Defaults to 'true'.
-        #
-        #include_profile_data_on_invite: false
-
-        # If set to 'true', removes the need for authentication to access the server's
-        # public rooms directory through the client API, meaning that anyone can
-        # query the room directory. Defaults to 'false'.
-        #
-        #allow_public_rooms_without_auth: true
-
-        # If set to 'true', allows any other homeserver to fetch the server's public
-        # rooms directory via federation. Defaults to 'false'.
-        #
-        #allow_public_rooms_over_federation: true
-
-        # The default room version for newly created rooms.
-        #
-        # Known room versions are listed here:
-        # https://spec.matrix.org/latest/rooms/#complete-list-of-room-versions
-        #
-        # For example, for room version 1, default_room_version should be set
-        # to "1".
-        #
-        #default_room_version: "%(default_room_version)s"
-
-        # The GC threshold parameters to pass to `gc.set_threshold`, if defined
-        #
-        #gc_thresholds: [700, 10, 10]
-
-        # The minimum time in seconds between each GC for a generation, regardless of
-        # the GC thresholds. This ensures that we don't do GC too frequently.
-        #
-        # A value of `[1s, 10s, 30s]` indicates that a second must pass between consecutive
-        # generation 0 GCs, etc.
-        #
-        # Defaults to `[1s, 10s, 30s]`.
-        #
-        #gc_min_interval: [0.5s, 30s, 1m]
-
-        # Set the limit on the returned events in the timeline in the get
-        # and sync operations. The default value is 100. -1 means no upper limit.
-        #
-        # Uncomment the following to increase the limit to 5000.
-        #
-        #filter_timeline_limit: 5000
-
-        # Whether room invites to users on this server should be blocked
-        # (except those sent by local server admins). The default is False.
-        #
-        #block_non_admin_invites: true
-
-        # Room searching
-        #
-        # If disabled, new messages will not be indexed for searching and users
-        # will receive errors when searching for messages. Defaults to enabled.
-        #
-        #enable_search: false
-
-        # Prevent outgoing requests from being sent to the following blacklisted IP address
-        # CIDR ranges. If this option is not specified then it defaults to private IP
-        # address ranges (see the example below).
-        #
-        # The blacklist applies to the outbound requests for federation, identity servers,
-        # push servers, and for checking key validity for third-party invite events.
-        #
-        # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
-        # listed here, since they correspond to unroutable addresses.)
-        #
-        # This option replaces federation_ip_range_blacklist in Synapse v1.25.0.
-        #
-        # Note: The value is ignored when an HTTP proxy is in use
-        #
-        #ip_range_blacklist:
-%(ip_range_blacklist)s
-
-        # List of IP address CIDR ranges that should be allowed for federation,
-        # identity servers, push servers, and for checking key validity for
-        # third-party invite events. This is useful for specifying exceptions to
-        # wide-ranging blacklisted target IP ranges - e.g. for communication with
-        # a push server only visible in your network.
-        #
-        # This whitelist overrides ip_range_blacklist and defaults to an empty
-        # list.
-        #
-        #ip_range_whitelist:
-        #   - '192.168.1.1'
-
-        # List of ports that Synapse should listen on, their purpose and their
-        # configuration.
-        #
-        # Options for each listener include:
-        #
-        #   port: the TCP port to bind to
-        #
-        #   bind_addresses: a list of local addresses to listen on. The default is
-        #       'all local interfaces'.
-        #
-        #   type: the type of listener. Normally 'http', but other valid options are:
-        #       'manhole' (see https://matrix-org.github.io/synapse/latest/manhole.html),
-        #       'metrics' (see https://matrix-org.github.io/synapse/latest/metrics-howto.html),
-        #       'replication' (see https://matrix-org.github.io/synapse/latest/workers.html).
-        #
-        #   tls: set to true to enable TLS for this listener. Will use the TLS
-        #       key/cert specified in tls_private_key_path / tls_certificate_path.
-        #
-        #   x_forwarded: Only valid for an 'http' listener. Set to true to use the
-        #       X-Forwarded-For header as the client IP. Useful when Synapse is
-        #       behind a reverse-proxy.
-        #
-        #   resources: Only valid for an 'http' listener. A list of resources to host
-        #       on this port. Options for each resource are:
-        #
-        #       names: a list of names of HTTP resources. See below for a list of
-        #           valid resource names.
-        #
-        #       compress: set to true to enable HTTP compression for this resource.
-        #
-        #   additional_resources: Only valid for an 'http' listener. A map of
-        #        additional endpoints which should be loaded via dynamic modules.
-        #
-        # Valid resource names are:
-        #
-        #   client: the client-server API (/_matrix/client), and the synapse admin
-        #       API (/_synapse/admin). Also implies 'media' and 'static'.
-        #
-        #   consent: user consent forms (/_matrix/consent).
-        #       See https://matrix-org.github.io/synapse/latest/consent_tracking.html.
-        #
-        #   federation: the server-server API (/_matrix/federation). Also implies
-        #       'media', 'keys', 'openid'
-        #
-        #   keys: the key discovery API (/_matrix/key).
-        #
-        #   media: the media API (/_matrix/media).
-        #
-        #   metrics: the metrics interface.
-        #       See https://matrix-org.github.io/synapse/latest/metrics-howto.html.
-        #
-        #   openid: OpenID authentication.
-        #
-        #   replication: the HTTP replication API (/_synapse/replication).
-        #       See https://matrix-org.github.io/synapse/latest/workers.html.
-        #
-        #   static: static resources under synapse/static (/_matrix/static). (Mostly
-        #       useful for 'fallback authentication'.)
-        #
         listeners:
-          # TLS-enabled listener: for when matrix traffic is sent directly to synapse.
-          #
-          # Disabled by default. To enable it, uncomment the following. (Note that you
-          # will also need to give Synapse a TLS key and certificate: see the TLS section
-          # below.)
-          #
           %(secure_http_bindings)s
-
-          # Unsecure HTTP listener: for when matrix traffic passes through a reverse proxy
-          # that unwraps TLS.
-          #
-          # If you plan to use a reverse proxy, please see
-          # https://matrix-org.github.io/synapse/latest/reverse_proxy.html.
-          #
           %(unsecure_http_bindings)s
-
-            # example additional_resources:
-            #
-            #additional_resources:
-            #  "/_matrix/my/custom/endpoint":
-            #    module: my_module.CustomRequestHandler
-            #    config: {}
-
-          # Turn on the twisted ssh manhole service on localhost on the given
-          # port.
-          #
-          #- port: 9000
-          #  bind_addresses: ['::1', '127.0.0.1']
-          #  type: manhole
-
-        # Connection settings for the manhole
-        #
-        manhole_settings:
-          # The username for the manhole. This defaults to 'matrix'.
-          #
-          #username: manhole
-
-          # The password for the manhole. This defaults to 'rabbithole'.
-          #
-          #password: mypassword
-
-          # The private and public SSH key pair used to encrypt the manhole traffic.
-          # If these are left unset, then hardcoded and non-secret keys are used,
-          # which could allow traffic to be intercepted if sent over a public network.
-          #
-          #ssh_priv_key_path: %(config_dir_path)s/id_rsa
-          #ssh_pub_key_path: %(config_dir_path)s/id_rsa.pub
-
-        # Forward extremities can build up in a room due to networking delays between
-        # homeservers. Once this happens in a large room, calculation of the state of
-        # that room can become quite expensive. To mitigate this, once the number of
-        # forward extremities reaches a given threshold, Synapse will send an
-        # org.matrix.dummy_event event, which will reduce the forward extremities
-        # in the room.
-        #
-        # This setting defines the threshold (i.e. number of forward extremities in the
-        # room) at which dummy events are sent. The default value is 10.
-        #
-        #dummy_events_threshold: 5
-
-
-        ## Homeserver blocking ##
-
-        # How to reach the server admin, used in ResourceLimitError
-        #
-        #admin_contact: 'mailto:admin@server.com'
-
-        # Global blocking
-        #
-        #hs_disabled: false
-        #hs_disabled_message: 'Human readable reason for why the HS is blocked'
-
-        # Monthly Active User Blocking
-        #
-        # Used in cases where the admin or server owner wants to limit to the
-        # number of monthly active users.
-        #
-        # 'limit_usage_by_mau' disables/enables monthly active user blocking. When
-        # enabled and a limit is reached the server returns a 'ResourceLimitError'
-        # with error type Codes.RESOURCE_LIMIT_EXCEEDED
-        #
-        # 'max_mau_value' is the hard limit of monthly active users above which
-        # the server will start blocking user actions.
-        #
-        # 'mau_trial_days' is a means to add a grace period for active users. It
-        # means that users must be active for this number of days before they
-        # can be considered active and guards against the case where lots of users
-        # sign up in a short space of time never to return after their initial
-        # session.
-        #
-        # The option `mau_appservice_trial_days` is similar to `mau_trial_days`, but
-        # applies a different trial number if the user was registered by an appservice.
-        # A value of 0 means no trial days are applied. Appservices not listed in this
-        # dictionary use the value of `mau_trial_days` instead.
-        #
-        # 'mau_limit_alerting' is a means of limiting client side alerting
-        # should the mau limit be reached. This is useful for small instances
-        # where the admin has 5 mau seats (say) for 5 specific people and no
-        # interest increasing the mau limit further. Defaults to True, which
-        # means that alerting is enabled
-        #
-        #limit_usage_by_mau: false
-        #max_mau_value: 50
-        #mau_trial_days: 2
-        #mau_limit_alerting: false
-        #mau_appservice_trial_days:
-        #  "appservice-id": 1
-
-        # If enabled, the metrics for the number of monthly active users will
-        # be populated, however no one will be limited. If limit_usage_by_mau
-        # is true, this is implied to be true.
-        #
-        #mau_stats_only: false
-
-        # Sometimes the server admin will want to ensure certain accounts are
-        # never blocked by mau checking. These accounts are specified here.
-        #
-        #mau_limit_reserved_threepids:
-        #  - medium: 'email'
-        #    address: 'reserved_user@example.com'
-
-        # Used by phonehome stats to group together related servers.
-        #server_context: context
-
-        # Resource-constrained homeserver settings
-        #
-        # When this is enabled, the room "complexity" will be checked before a user
-        # joins a new remote room. If it is above the complexity limit, the server will
-        # disallow joining, or will instantly leave.
-        #
-        # Room complexity is an arbitrary measure based on factors such as the number of
-        # users in the room.
-        #
-        limit_remote_rooms:
-          # Uncomment to enable room complexity checking.
-          #
-          #enabled: true
-
-          # the limit above which rooms cannot be joined. The default is 1.0.
-          #
-          #complexity: 0.5
-
-          # override the error which is returned when the room is too complex.
-          #
-          #complexity_error: "This room is too complex."
-
-          # allow server admins to join complex rooms. Default is false.
-          #
-          #admins_can_join: true
-
-        # Whether to require a user to be in the room to add an alias to it.
-        # Defaults to 'true'.
-        #
-        #require_membership_for_aliases: false
-
-        # Whether to allow per-room membership profiles through the send of membership
-        # events with profile information that differ from the target's global profile.
-        # Defaults to 'true'.
-        #
-        #allow_per_room_profiles: false
-
-        # The largest allowed file size for a user avatar. Defaults to no restriction.
-        #
-        # Note that user avatar changes will not work if this is set without
-        # using Synapse's media repository.
-        #
-        #max_avatar_size: 10M
-
-        # The MIME types allowed for user avatars. Defaults to no restriction.
-        #
-        # Note that user avatar changes will not work if this is set without
-        # using Synapse's media repository.
-        #
-        #allowed_avatar_mimetypes: ["image/png", "image/jpeg", "image/gif"]
-
-        # How long to keep redacted events in unredacted form in the database. After
-        # this period redacted events get replaced with their redacted form in the DB.
-        #
-        # Defaults to `7d`. Set to `null` to disable.
-        #
-        #redaction_retention_period: 28d
-
-        # How long to track users' last seen time and IPs in the database.
-        #
-        # Defaults to `28d`. Set to `null` to disable clearing out of old rows.
-        #
-        #user_ips_max_age: 14d
-
-        # Inhibits the /requestToken endpoints from returning an error that might leak
-        # information about whether an e-mail address is in use or not on this
-        # homeserver.
-        # Note that for some endpoints the error situation is the e-mail already being
-        # used, and for others the error is entering the e-mail being unused.
-        # If this option is enabled, instead of returning an error, these endpoints will
-        # act as if no error happened and return a fake session ID ('sid') to clients.
-        #
-        #request_token_inhibit_3pid_errors: true
-
-        # A list of domains that the domain portion of 'next_link' parameters
-        # must match.
-        #
-        # This parameter is optionally provided by clients while requesting
-        # validation of an email or phone number, and maps to a link that
-        # users will be automatically redirected to after validation
-        # succeeds. Clients can make use this parameter to aid the validation
-        # process.
-        #
-        # The whitelist is applied whether the homeserver or an
-        # identity server is handling validation.
-        #
-        # The default value is no whitelist functionality; all domains are
-        # allowed. Setting this value to an empty list will instead disallow
-        # all domains.
-        #
-        #next_link_domain_whitelist: ["matrix.org"]
-
-        # Templates to use when generating email or HTML page contents.
-        #
-        templates:
-          # Directory in which Synapse will try to find template files to use to generate
-          # email or HTML page contents.
-          # If not set, or a file is not found within the template directory, a default
-          # template from within the Synapse package will be used.
-          #
-          # See https://matrix-org.github.io/synapse/latest/templates.html for more
-          # information about using custom templates.
-          #
-          #custom_template_directory: /path/to/custom/templates/
-
-        # List of rooms to exclude from sync responses. This is useful for server
-        # administrators wishing to group users into a room without these users being able
-        # to see it from their client.
-        #
-        # By default, no room is excluded.
-        #
-        #exclude_rooms_from_sync:
-        #    - !foo:example.com
         """
             % locals()
         )
diff --git a/synapse/config/server_notices.py b/synapse/config/server_notices.py
index 505b4f6c6c..ce041abe9b 100644
--- a/synapse/config/server_notices.py
+++ b/synapse/config/server_notices.py
@@ -18,27 +18,6 @@ from synapse.types import JsonDict, UserID
 
 from ._base import Config
 
-DEFAULT_CONFIG = """\
-# Server Notices room configuration
-#
-# Uncomment this section to enable a room which can be used to send notices
-# from the server to users. It is a special room which cannot be left; notices
-# come from a special "notices" user id.
-#
-# If you uncomment this section, you *must* define the system_mxid_localpart
-# setting, which defines the id of the user which will be used to send the
-# notices.
-#
-# It's also possible to override the room name, the display name of the
-# "notices" user, and the avatar for the user.
-#
-#server_notices:
-#  system_mxid_localpart: notices
-#  system_mxid_display_name: "Server Notices"
-#  system_mxid_avatar_url: "mxc://server.com/oumMVlgDnLYFaPVkExemNVVZ"
-#  room_name: "Server Notices"
-"""
-
 
 class ServerNoticesConfig(Config):
     """Configuration for the server notices room.
@@ -83,6 +62,3 @@ class ServerNoticesConfig(Config):
         self.server_notices_mxid_avatar_url = c.get("system_mxid_avatar_url", None)
         # todo: i18n
         self.server_notices_room_name = c.get("room_name", "Server Notices")
-
-    def generate_config_section(self, **kwargs: Any) -> str:
-        return DEFAULT_CONFIG
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index f88eba77d0..a452cc3a49 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -26,7 +26,7 @@ LEGACY_TEMPLATE_DIR_WARNING = """
 This server's configuration file is using the deprecated 'template_dir' setting in the
 'sso' section. Support for this setting has been deprecated and will be removed in a
 future version of Synapse. Server admins should instead use the new
-'custom_templates_directory' setting documented here:
+'custom_template_directory' setting documented here:
 https://matrix-org.github.io/synapse/latest/templates.html
 ---------------------------------------------------------------------------------------"""
 
@@ -107,43 +107,3 @@ class SSOConfig(Config):
             self.root.server.public_baseurl + "_matrix/static/client/login"
         )
         self.sso_client_whitelist.append(login_fallback_url)
-
-    def generate_config_section(self, **kwargs: Any) -> str:
-        return """\
-        # Additional settings to use with single-sign on systems such as OpenID Connect,
-        # SAML2 and CAS.
-        #
-        # Server admins can configure custom templates for pages related to SSO. See
-        # https://matrix-org.github.io/synapse/latest/templates.html for more information.
-        #
-        sso:
-            # A list of client URLs which are whitelisted so that the user does not
-            # have to confirm giving access to their account to the URL. Any client
-            # whose URL starts with an entry in the following list will not be subject
-            # to an additional confirmation step after the SSO login is completed.
-            #
-            # WARNING: An entry such as "https://my.client" is insecure, because it
-            # will also match "https://my.client.evil.site", exposing your users to
-            # phishing attacks from evil.site. To avoid this, include a slash after the
-            # hostname: "https://my.client/".
-            #
-            # The login fallback page (used by clients that don't natively support the
-            # required login flows) is whitelisted in addition to any URLs in this list.
-            #
-            # By default, this list contains only the login fallback page.
-            #
-            #client_whitelist:
-            #  - https://riot.im/develop
-            #  - https://my.custom.client/
-
-            # Uncomment to keep a user's profile fields in sync with information from
-            # the identity provider. Currently only syncing the displayname is
-            # supported. Fields are checked on every SSO login, and are updated
-            # if necessary.
-            #
-            # Note that enabling this option will override user profile information,
-            # regardless of whether users have opted-out of syncing that
-            # information when first signing in. Defaults to false.
-            #
-            #update_profile_information: true
-        """
diff --git a/synapse/config/stats.py b/synapse/config/stats.py
index ed1f416e4f..9621acd77f 100644
--- a/synapse/config/stats.py
+++ b/synapse/config/stats.py
@@ -46,16 +46,3 @@ class StatsConfig(Config):
             self.stats_enabled = stats_config.get("enabled", self.stats_enabled)
         if not self.stats_enabled:
             logger.warning(ROOM_STATS_DISABLED_WARN)
-
-    def generate_config_section(self, **kwargs: Any) -> str:
-        return """
-        # Settings for local room and user statistics collection. See
-        # https://matrix-org.github.io/synapse/latest/room_and_user_statistics.html.
-        #
-        stats:
-          # Uncomment the following to disable room and user statistics. Note that doing
-          # so may cause certain features (such as the room directory) not to work
-          # correctly.
-          #
-          #enabled: false
-        """
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index cb17950d25..336fe3e0da 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import logging
-import os
 from typing import Any, List, Optional, Pattern
 
 from matrix_common.regex import glob_to_regex
@@ -143,9 +142,6 @@ class TlsConfig(Config):
 
     def generate_config_section(
         self,
-        config_dir_path: str,
-        data_dir_path: str,
-        server_name: str,
         tls_certificate_path: Optional[str],
         tls_private_key_path: Optional[str],
         **kwargs: Any,
@@ -153,90 +149,18 @@ class TlsConfig(Config):
         """If the TLS paths are not specified the default will be certs in the
         config directory"""
 
-        base_key_name = os.path.join(config_dir_path, server_name)
-
         if bool(tls_certificate_path) != bool(tls_private_key_path):
             raise ConfigError(
                 "Please specify both a cert path and a key path or neither."
             )
 
-        tls_enabled = "" if tls_certificate_path and tls_private_key_path else "#"
-
-        if not tls_certificate_path:
-            tls_certificate_path = base_key_name + ".tls.crt"
-        if not tls_private_key_path:
-            tls_private_key_path = base_key_name + ".tls.key"
-
-        # flake8 doesn't recognise that variables are used in the below string
-        _ = tls_enabled
-
-        return (
-            """\
-        ## TLS ##
-
-        # PEM-encoded X509 certificate for TLS.
-        # This certificate, as of Synapse 1.0, will need to be a valid and verifiable
-        # certificate, signed by a recognised Certificate Authority.
-        #
-        # Be sure to use a `.pem` file that includes the full certificate chain including
-        # any intermediate certificates (for instance, if using certbot, use
-        # `fullchain.pem` as your certificate, not `cert.pem`).
-        #
-        %(tls_enabled)stls_certificate_path: "%(tls_certificate_path)s"
-
-        # PEM-encoded private key for TLS
-        #
-        %(tls_enabled)stls_private_key_path: "%(tls_private_key_path)s"
-
-        # Whether to verify TLS server certificates for outbound federation requests.
-        #
-        # Defaults to `true`. To disable certificate verification, uncomment the
-        # following line.
-        #
-        #federation_verify_certificates: false
-
-        # The minimum TLS version that will be used for outbound federation requests.
-        #
-        # Defaults to `1`. Configurable to `1`, `1.1`, `1.2`, or `1.3`. Note
-        # that setting this value higher than `1.2` will prevent federation to most
-        # of the public Matrix network: only configure it to `1.3` if you have an
-        # entirely private federation setup and you can ensure TLS 1.3 support.
-        #
-        #federation_client_minimum_tls_version: 1.2
-
-        # Skip federation certificate verification on the following whitelist
-        # of domains.
-        #
-        # This setting should only be used in very specific cases, such as
-        # federation over Tor hidden services and similar. For private networks
-        # of homeservers, you likely want to use a private CA instead.
-        #
-        # Only effective if federation_verify_certicates is `true`.
-        #
-        #federation_certificate_verification_whitelist:
-        #  - lon.example.com
-        #  - "*.domain.com"
-        #  - "*.onion"
-
-        # List of custom certificate authorities for federation traffic.
-        #
-        # This setting should only normally be used within a private network of
-        # homeservers.
-        #
-        # Note that this list will replace those that are provided by your
-        # operating environment. Certificates must be in PEM format.
-        #
-        #federation_custom_ca_list:
-        #  - myCA1.pem
-        #  - myCA2.pem
-        #  - myCA3.pem
-        """
-            # Lowercase the string representation of boolean values
-            % {
-                x[0]: str(x[1]).lower() if isinstance(x[1], bool) else x[1]
-                for x in locals().items()
-            }
-        )
+        if tls_certificate_path and tls_private_key_path:
+            return f"""\
+                tls_certificate_path: {tls_certificate_path}
+                tls_private_key_path: {tls_private_key_path}
+                """
+        else:
+            return ""
 
     def read_tls_certificate(self) -> crypto.X509:
         """Reads the TLS certificate from the configured file, and returns it
diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py
index ae68a3dd1a..c19270c6c5 100644
--- a/synapse/config/tracer.py
+++ b/synapse/config/tracer.py
@@ -15,7 +15,7 @@
 from typing import Any, List, Set
 
 from synapse.types import JsonDict
-from synapse.util.check_dependencies import DependencyException, check_requirements
+from synapse.util.check_dependencies import check_requirements
 
 from ._base import Config, ConfigError
 
@@ -40,12 +40,7 @@ class TracerConfig(Config):
         if not self.opentracer_enabled:
             return
 
-        try:
-            check_requirements("opentracing")
-        except DependencyException as e:
-            raise ConfigError(
-                e.message  # noqa: B306, DependencyException.message is a property
-            )
+        check_requirements("opentracing")
 
         # The tracer is enabled so sanitize the config
 
@@ -67,53 +62,3 @@ class TracerConfig(Config):
                     ("opentracing", "force_tracing_for_users", f"index {i}"),
                 )
             self.force_tracing_for_users.add(u)
-
-    def generate_config_section(cls, **kwargs: Any) -> str:
-        return """\
-        ## Opentracing ##
-
-        # These settings enable opentracing, which implements distributed tracing.
-        # This allows you to observe the causal chains of events across servers
-        # including requests, key lookups etc., across any server running
-        # synapse or any other other services which supports opentracing
-        # (specifically those implemented with Jaeger).
-        #
-        opentracing:
-            # tracing is disabled by default. Uncomment the following line to enable it.
-            #
-            #enabled: true
-
-            # The list of homeservers we wish to send and receive span contexts and span baggage.
-            # See https://matrix-org.github.io/synapse/latest/opentracing.html.
-            #
-            # This is a list of regexes which are matched against the server_name of the
-            # homeserver.
-            #
-            # By default, it is empty, so no servers are matched.
-            #
-            #homeserver_whitelist:
-            #  - ".*"
-
-            # A list of the matrix IDs of users whose requests will always be traced,
-            # even if the tracing system would otherwise drop the traces due to
-            # probabilistic sampling.
-            #
-            # By default, the list is empty.
-            #
-            #force_tracing_for_users:
-            #  - "@user1:server_name"
-            #  - "@user2:server_name"
-
-            # Jaeger can be configured to sample traces at different rates.
-            # All configuration options provided by Jaeger can be set here.
-            # Jaeger's configuration is mostly related to trace sampling which
-            # is documented here:
-            # https://www.jaegertracing.io/docs/latest/sampling/.
-            #
-            #jaeger_config:
-            #  sampler:
-            #    type: const
-            #    param: 1
-            #  logging:
-            #    false
-        """
diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
index 010e791924..c9e18b91e9 100644
--- a/synapse/config/user_directory.py
+++ b/synapse/config/user_directory.py
@@ -35,42 +35,3 @@ class UserDirectoryConfig(Config):
         self.user_directory_search_prefer_local_users = user_directory_config.get(
             "prefer_local_users", False
         )
-
-    def generate_config_section(self, **kwargs: Any) -> str:
-        return """
-        # User Directory configuration
-        #
-        user_directory:
-            # Defines whether users can search the user directory. If false then
-            # empty responses are returned to all queries. Defaults to true.
-            #
-            # Uncomment to disable the user directory.
-            #
-            #enabled: false
-
-            # Defines whether to search all users visible to your HS when searching
-            # the user directory. If false, search results will only contain users
-            # visible in public rooms and users sharing a room with the requester.
-            # Defaults to false.
-            #
-            # NB. If you set this to true, and the last time the user_directory search
-            # indexes were (re)built was before Synapse 1.44, you'll have to
-            # rebuild the indexes in order to search through all known users.
-            # These indexes are built the first time Synapse starts; admins can
-            # manually trigger a rebuild via API following the instructions at
-            #     https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/background_updates.html#run
-            #
-            # Uncomment to return search results containing all known users, even if that
-            # user does not share a room with the requester.
-            #
-            #search_all_users: true
-
-            # Defines whether to prefer local users in search query results.
-            # If True, local users are more likely to appear above remote users
-            # when searching the user directory. Defaults to false.
-            #
-            # Uncomment to prefer local over remote users in user directory search
-            # results.
-            #
-            #prefer_local_users: true
-        """
diff --git a/synapse/config/voip.py b/synapse/config/voip.py
index 87c09abe24..43f0a0fa17 100644
--- a/synapse/config/voip.py
+++ b/synapse/config/voip.py
@@ -31,34 +31,3 @@ class VoipConfig(Config):
             config.get("turn_user_lifetime", "1h")
         )
         self.turn_allow_guests = config.get("turn_allow_guests", True)
-
-    def generate_config_section(self, **kwargs: Any) -> str:
-        return """\
-        ## TURN ##
-
-        # The public URIs of the TURN server to give to clients
-        #
-        #turn_uris: []
-
-        # The shared secret used to compute passwords for the TURN server
-        #
-        #turn_shared_secret: "YOUR_SHARED_SECRET"
-
-        # The Username and password if the TURN server needs them and
-        # does not use a token
-        #
-        #turn_username: "TURNSERVER_USERNAME"
-        #turn_password: "TURNSERVER_PASSWORD"
-
-        # How long generated TURN credentials last
-        #
-        #turn_user_lifetime: 1h
-
-        # Whether guests should be allowed to use the TURN server.
-        # This defaults to True, otherwise VoIP will be unreliable for guests.
-        # However, it does introduce a slight security risk as it allows users to
-        # connect to arbitrary endpoints without having first signed up for a
-        # valid account (e.g. by passing a CAPTCHA).
-        #
-        #turn_allow_guests: true
-        """
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index e1569b3c14..f2716422b5 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -410,55 +410,6 @@ class WorkerConfig(Config):
         # (By this point, these are either the same value or only one is not None.)
         return bool(new_option_should_run_here or legacy_option_should_run_here)
 
-    def generate_config_section(self, **kwargs: Any) -> str:
-        return """\
-        ## Workers ##
-
-        # Disables sending of outbound federation transactions on the main process.
-        # Uncomment if using a federation sender worker.
-        #
-        #send_federation: false
-
-        # It is possible to run multiple federation sender workers, in which case the
-        # work is balanced across them.
-        #
-        # This configuration must be shared between all federation sender workers, and if
-        # changed all federation sender workers must be stopped at the same time and then
-        # started, to ensure that all instances are running with the same config (otherwise
-        # events may be dropped).
-        #
-        #federation_sender_instances:
-        #  - federation_sender1
-
-        # When using workers this should be a map from `worker_name` to the
-        # HTTP replication listener of the worker, if configured.
-        #
-        #instance_map:
-        #  worker1:
-        #    host: localhost
-        #    port: 8034
-
-        # Experimental: When using workers you can define which workers should
-        # handle event persistence and typing notifications. Any worker
-        # specified here must also be in the `instance_map`.
-        #
-        #stream_writers:
-        #  events: worker1
-        #  typing: worker1
-
-        # The worker that is used to run background tasks (e.g. cleaning up expired
-        # data). If not provided this defaults to the main process.
-        #
-        #run_background_tasks_on: worker1
-
-        # A shared secret used by the replication APIs to authenticate HTTP requests
-        # from workers.
-        #
-        # By default this is unused and traffic is not authenticated.
-        #
-        #worker_replication_secret: ""
-        """
-
     def read_arguments(self, args: argparse.Namespace) -> None:
         # We support a bunch of command line arguments that override options in
         # the config. A lot of these options have a worker_* prefix when running
diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py
index 7520647d1e..23b799ac32 100644
--- a/synapse/crypto/event_signing.py
+++ b/synapse/crypto/event_signing.py
@@ -28,6 +28,7 @@ from synapse.api.errors import Codes, SynapseError
 from synapse.api.room_versions import RoomVersion
 from synapse.events import EventBase
 from synapse.events.utils import prune_event, prune_event_dict
+from synapse.logging.opentracing import trace
 from synapse.types import JsonDict
 
 logger = logging.getLogger(__name__)
@@ -35,6 +36,7 @@ logger = logging.getLogger(__name__)
 Hasher = Callable[[bytes], "hashlib._Hash"]
 
 
+@trace
 def check_event_content_hash(
     event: EventBase, hash_algorithm: Hasher = hashlib.sha256
 ) -> bool:
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 4c0b587a76..389b0c5d53 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -15,11 +15,12 @@
 
 import logging
 import typing
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import Any, Collection, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 from canonicaljson import encode_canonical_json
 from signedjson.key import decode_verify_key_bytes
 from signedjson.sign import SignatureVerifyException, verify_signed_json
+from typing_extensions import Protocol
 from unpaddedbase64 import decode_base64
 
 from synapse.api.constants import (
@@ -29,13 +30,20 @@ from synapse.api.constants import (
     JoinRules,
     Membership,
 )
-from synapse.api.errors import AuthError, EventSizeError, SynapseError
+from synapse.api.errors import (
+    AuthError,
+    Codes,
+    EventSizeError,
+    SynapseError,
+    UnstableSpecAuthError,
+)
 from synapse.api.room_versions import (
     KNOWN_ROOM_VERSIONS,
     EventFormatVersions,
     RoomVersion,
 )
-from synapse.types import StateMap, UserID, get_domain_from_id
+from synapse.storage.databases.main.events_worker import EventRedactBehaviour
+from synapse.types import MutableStateMap, StateMap, UserID, get_domain_from_id
 
 if typing.TYPE_CHECKING:
     # conditional imports to avoid import cycle
@@ -45,9 +53,18 @@ if typing.TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-def validate_event_for_room_version(
-    room_version_obj: RoomVersion, event: "EventBase"
-) -> None:
+class _EventSourceStore(Protocol):
+    async def get_events(
+        self,
+        event_ids: Collection[str],
+        redact_behaviour: EventRedactBehaviour,
+        get_prev_content: bool = False,
+        allow_rejected: bool = False,
+    ) -> Dict[str, "EventBase"]:
+        ...
+
+
+def validate_event_for_room_version(event: "EventBase") -> None:
     """Ensure that the event complies with the limits, and has the right signatures
 
     NB: does not *validate* the signatures - it assumes that any signatures present
@@ -60,12 +77,10 @@ def validate_event_for_room_version(
     NB: This is used to check events that have been received over federation. As such,
     it can only enforce the checks specified in the relevant room version, to avoid
     a split-brain situation where some servers accept such events, and others reject
-    them.
-
-    TODO: consider moving this into EventValidator
+    them. See also EventValidator, which contains extra checks which are applied only to
+    locally-generated events.
 
     Args:
-        room_version_obj: the version of the room which contains this event
         event: the event to be checked
 
     Raises:
@@ -103,7 +118,7 @@ def validate_event_for_room_version(
             raise AuthError(403, "Event not signed by sending server")
 
     is_invite_via_allow_rule = (
-        room_version_obj.msc3083_join_rules
+        event.room_version.msc3083_join_rules
         and event.type == EventTypes.Member
         and event.membership == Membership.JOIN
         and EventContentFields.AUTHORISING_USER in event.content
@@ -116,87 +131,127 @@ def validate_event_for_room_version(
             raise AuthError(403, "Event not signed by authorising server")
 
 
-def check_auth_rules_for_event(
-    room_version_obj: RoomVersion,
+async def check_state_independent_auth_rules(
+    store: _EventSourceStore,
     event: "EventBase",
-    auth_events: Iterable["EventBase"],
 ) -> None:
-    """Check that an event complies with the auth rules
+    """Check that an event complies with auth rules that are independent of room state
 
-    Checks whether an event passes the auth rules with a given set of state events
-
-    Assumes that we have already checked that the event is the right shape (it has
-    enough signatures, has a room ID, etc). In other words:
-
-     - it's fine for use in state resolution, when we have already decided whether to
-       accept the event or not, and are now trying to decide whether it should make it
-       into the room state
-
-     - when we're doing the initial event auth, it is only suitable in combination with
-       a bunch of other tests.
+    Runs through the first few auth rules, which are independent of room state. (Which
+    means that we only need to them once for each received event)
 
     Args:
-        room_version_obj: the version of the room
+        store: the datastore; used to fetch the auth events for validation
         event: the event being checked.
-        auth_events: the room state to check the events against.
 
     Raises:
         AuthError if the checks fail
     """
-    # We need to ensure that the auth events are actually for the same room, to
-    # stop people from using powers they've been granted in other rooms for
-    # example.
-    #
-    # Arguably we don't need to do this when we're just doing state res, as presumably
-    # the state res algorithm isn't silly enough to give us events from different rooms.
-    # Still, it's easier to do it anyway.
+    # Implementation of https://spec.matrix.org/v1.2/rooms/v9/#authorization-rules
+
+    # 1. If type is m.room.create:
+    if event.type == EventTypes.Create:
+        _check_create(event)
+
+        # 1.5 Otherwise, allow
+        return
+
+    # 2. Reject if event has auth_events that: ...
+    auth_events = await store.get_events(
+        event.auth_event_ids(),
+        redact_behaviour=EventRedactBehaviour.as_is,
+        allow_rejected=True,
+    )
     room_id = event.room_id
-    for auth_event in auth_events:
+    auth_dict: MutableStateMap[str] = {}
+    expected_auth_types = auth_types_for_event(event.room_version, event)
+    for auth_event_id in event.auth_event_ids():
+        auth_event = auth_events.get(auth_event_id)
+
+        # we should have all the auth events by now, so if we do not, that suggests
+        # a synapse programming error
+        if auth_event is None:
+            raise RuntimeError(
+                f"Event {event.event_id} has unknown auth event {auth_event_id}"
+            )
+
+        # We need to ensure that the auth events are actually for the same room, to
+        # stop people from using powers they've been granted in other rooms for
+        # example.
         if auth_event.room_id != room_id:
             raise AuthError(
                 403,
                 "During auth for event %s in room %s, found event %s in the state "
                 "which is in room %s"
-                % (event.event_id, room_id, auth_event.event_id, auth_event.room_id),
+                % (event.event_id, room_id, auth_event_id, auth_event.room_id),
             )
-        if auth_event.rejected_reason:
+
+        k = (auth_event.type, auth_event.state_key)
+
+        # 2.1 ... have duplicate entries for a given type and state_key pair
+        if k in auth_dict:
             raise AuthError(
                 403,
-                "During auth for event %s: found rejected event %s in the state"
-                % (event.event_id, auth_event.event_id),
+                f"Event {event.event_id} has duplicate auth_events for {k}: {auth_dict[k]} and {auth_event_id}",
             )
 
-    # Implementation of https://matrix.org/docs/spec/rooms/v1#authorization-rules
-    #
-    # 1. If type is m.room.create:
-    if event.type == EventTypes.Create:
-        # 1b. If the domain of the room_id does not match the domain of the sender,
-        # reject.
-        sender_domain = get_domain_from_id(event.sender)
-        room_id_domain = get_domain_from_id(event.room_id)
-        if room_id_domain != sender_domain:
+        # 2.2 ... have entries whose type and state_key don’t match those specified by
+        #   the auth events selection algorithm described in the server
+        #   specification.
+        if k not in expected_auth_types:
             raise AuthError(
-                403, "Creation event's room_id domain does not match sender's"
+                403,
+                f"Event {event.event_id} has unexpected auth_event for {k}: {auth_event_id}",
             )
 
-        # 1c. If content.room_version is present and is not a recognised version, reject
-        room_version_prop = event.content.get("room_version", "1")
-        if room_version_prop not in KNOWN_ROOM_VERSIONS:
+        # We also need to check that the auth event itself is not rejected.
+        if auth_event.rejected_reason:
             raise AuthError(
                 403,
-                "room appears to have unsupported version %s" % (room_version_prop,),
+                "During auth for event %s: found rejected event %s in the state"
+                % (event.event_id, auth_event.event_id),
             )
 
-        logger.debug("Allowing! %s", event)
-        return
-
-    auth_dict = {(e.type, e.state_key): e for e in auth_events}
+        auth_dict[k] = auth_event_id
 
     # 3. If event does not have a m.room.create in its auth_events, reject.
     creation_event = auth_dict.get((EventTypes.Create, ""), None)
     if not creation_event:
         raise AuthError(403, "No create event in auth events")
 
+
+def check_state_dependent_auth_rules(
+    event: "EventBase",
+    auth_events: Iterable["EventBase"],
+) -> None:
+    """Check that an event complies with auth rules that depend on room state
+
+    Runs through the parts of the auth rules that check an event against bits of room
+    state.
+
+    Note:
+
+     - it's fine for use in state resolution, when we have already decided whether to
+       accept the event or not, and are now trying to decide whether it should make it
+       into the room state
+
+     - when we're doing the initial event auth, it is only suitable in combination with
+       a bunch of other tests (including, but not limited to, check_state_independent_auth_rules).
+
+    Args:
+        event: the event being checked.
+        auth_events: the room state to check the events against.
+
+    Raises:
+        AuthError if the checks fail
+    """
+    # there are no state-dependent auth rules for create events.
+    if event.type == EventTypes.Create:
+        logger.debug("Allowing! %s", event)
+        return
+
+    auth_dict = {(e.type, e.state_key): e for e in auth_events}
+
     # additional check for m.federate
     creating_domain = get_domain_from_id(event.room_id)
     originating_domain = get_domain_from_id(event.sender)
@@ -205,7 +260,10 @@ def check_auth_rules_for_event(
             raise AuthError(403, "This room has been marked as unfederatable.")
 
     # 4. If type is m.room.aliases
-    if event.type == EventTypes.Aliases and room_version_obj.special_case_aliases_auth:
+    if (
+        event.type == EventTypes.Aliases
+        and event.room_version.special_case_aliases_auth
+    ):
         # 4a. If event has no state_key, reject
         if not event.is_state():
             raise AuthError(403, "Alias event must be a state event")
@@ -225,7 +283,7 @@ def check_auth_rules_for_event(
 
     # 5. If type is m.room.membership
     if event.type == EventTypes.Member:
-        _is_membership_change_allowed(room_version_obj, event, auth_dict)
+        _is_membership_change_allowed(event.room_version, event, auth_dict)
         logger.debug("Allowing! %s", event)
         return
 
@@ -239,7 +297,11 @@ def check_auth_rules_for_event(
         invite_level = get_named_level(auth_dict, "invite", 0)
 
         if user_level < invite_level:
-            raise AuthError(403, "You don't have permission to invite users")
+            raise UnstableSpecAuthError(
+                403,
+                "You don't have permission to invite users",
+                errcode=Codes.INSUFFICIENT_POWER,
+            )
         else:
             logger.debug("Allowing! %s", event)
             return
@@ -247,17 +309,17 @@ def check_auth_rules_for_event(
     _can_send_event(event, auth_dict)
 
     if event.type == EventTypes.PowerLevels:
-        _check_power_levels(room_version_obj, event, auth_dict)
+        _check_power_levels(event.room_version, event, auth_dict)
 
     if event.type == EventTypes.Redaction:
-        check_redaction(room_version_obj, event, auth_dict)
+        check_redaction(event.room_version, event, auth_dict)
 
     if (
         event.type == EventTypes.MSC2716_INSERTION
         or event.type == EventTypes.MSC2716_BATCH
         or event.type == EventTypes.MSC2716_MARKER
     ):
-        check_historical(room_version_obj, event, auth_dict)
+        check_historical(event.room_version, event, auth_dict)
 
     logger.debug("Allowing! %s", event)
 
@@ -277,6 +339,41 @@ def _check_size_limits(event: "EventBase") -> None:
         raise EventSizeError("event too large")
 
 
+def _check_create(event: "EventBase") -> None:
+    """Implementation of the auth rules for m.room.create events
+
+    Args:
+        event: The `m.room.create` event to be checked
+
+    Raises:
+        AuthError if the event does not pass the auth rules
+    """
+    assert event.type == EventTypes.Create
+
+    #  1.1 If it has any previous events, reject.
+    if event.prev_event_ids():
+        raise AuthError(403, "Create event has prev events")
+
+    # 1.2 If the domain of the room_id does not match the domain of the sender,
+    # reject.
+    sender_domain = get_domain_from_id(event.sender)
+    room_id_domain = get_domain_from_id(event.room_id)
+    if room_id_domain != sender_domain:
+        raise AuthError(403, "Creation event's room_id domain does not match sender's")
+
+    # 1.3 If content.room_version is present and is not a recognised version, reject
+    room_version_prop = event.content.get("room_version", "1")
+    if room_version_prop not in KNOWN_ROOM_VERSIONS:
+        raise AuthError(
+            403,
+            "room appears to have unsupported version %s" % (room_version_prop,),
+        )
+
+    # 1.4 If content has no creator field, reject.
+    if EventContentFields.ROOM_CREATOR not in event.content:
+        raise AuthError(403, "Create event lacks a 'creator' property")
+
+
 def _can_federate(event: "EventBase", auth_events: StateMap["EventBase"]) -> bool:
     creation_event = auth_events.get((EventTypes.Create, ""))
     # There should always be a creation event, but if not don't federate.
@@ -387,7 +484,11 @@ def _is_membership_change_allowed(
             return
 
         if not caller_in_room:  # caller isn't joined
-            raise AuthError(403, "%s not in room %s." % (event.user_id, event.room_id))
+            raise UnstableSpecAuthError(
+                403,
+                "%s not in room %s." % (event.user_id, event.room_id),
+                errcode=Codes.NOT_JOINED,
+            )
 
     if Membership.INVITE == membership:
         # TODO (erikj): We should probably handle this more intelligently
@@ -397,10 +498,18 @@ def _is_membership_change_allowed(
         if target_banned:
             raise AuthError(403, "%s is banned from the room" % (target_user_id,))
         elif target_in_room:  # the target is already in the room.
-            raise AuthError(403, "%s is already in the room." % target_user_id)
+            raise UnstableSpecAuthError(
+                403,
+                "%s is already in the room." % target_user_id,
+                errcode=Codes.ALREADY_JOINED,
+            )
         else:
             if user_level < invite_level:
-                raise AuthError(403, "You don't have permission to invite users")
+                raise UnstableSpecAuthError(
+                    403,
+                    "You don't have permission to invite users",
+                    errcode=Codes.INSUFFICIENT_POWER,
+                )
     elif Membership.JOIN == membership:
         # Joins are valid iff caller == target and:
         # * They are not banned.
@@ -462,15 +571,27 @@ def _is_membership_change_allowed(
     elif Membership.LEAVE == membership:
         # TODO (erikj): Implement kicks.
         if target_banned and user_level < ban_level:
-            raise AuthError(403, "You cannot unban user %s." % (target_user_id,))
+            raise UnstableSpecAuthError(
+                403,
+                "You cannot unban user %s." % (target_user_id,),
+                errcode=Codes.INSUFFICIENT_POWER,
+            )
         elif target_user_id != event.user_id:
             kick_level = get_named_level(auth_events, "kick", 50)
 
             if user_level < kick_level or user_level <= target_level:
-                raise AuthError(403, "You cannot kick user %s." % target_user_id)
+                raise UnstableSpecAuthError(
+                    403,
+                    "You cannot kick user %s." % target_user_id,
+                    errcode=Codes.INSUFFICIENT_POWER,
+                )
     elif Membership.BAN == membership:
         if user_level < ban_level or user_level <= target_level:
-            raise AuthError(403, "You don't have permission to ban")
+            raise UnstableSpecAuthError(
+                403,
+                "You don't have permission to ban",
+                errcode=Codes.INSUFFICIENT_POWER,
+            )
     elif room_version.msc2403_knocking and Membership.KNOCK == membership:
         if join_rule != JoinRules.KNOCK and (
             not room_version.msc3787_knock_restricted_join_rule
@@ -480,7 +601,11 @@ def _is_membership_change_allowed(
         elif target_user_id != event.user_id:
             raise AuthError(403, "You cannot knock for other users")
         elif target_in_room:
-            raise AuthError(403, "You cannot knock on a room you are already in")
+            raise UnstableSpecAuthError(
+                403,
+                "You cannot knock on a room you are already in",
+                errcode=Codes.ALREADY_JOINED,
+            )
         elif caller_invited:
             raise AuthError(403, "You are already invited to this room")
         elif target_banned:
@@ -551,10 +676,11 @@ def _can_send_event(event: "EventBase", auth_events: StateMap["EventBase"]) -> b
     user_level = get_user_power_level(event.user_id, auth_events)
 
     if user_level < send_level:
-        raise AuthError(
+        raise UnstableSpecAuthError(
             403,
             "You don't have permission to post that to the room. "
             + "user_level (%d) < send_level (%d)" % (user_level, send_level),
+            errcode=Codes.INSUFFICIENT_POWER,
         )
 
     # Check state_key
@@ -629,9 +755,10 @@ def check_historical(
     historical_level = get_named_level(auth_events, "historical", 100)
 
     if user_level < historical_level:
-        raise AuthError(
+        raise UnstableSpecAuthError(
             403,
             'You don\'t have permission to send send historical related events ("insertion", "batch", and "marker")',
+            errcode=Codes.INSUFFICIENT_POWER,
         )
 
 
@@ -653,6 +780,32 @@ def _check_power_levels(
         except Exception:
             raise SynapseError(400, "Not a valid power level: %s" % (v,))
 
+    # Reject events with stringy power levels if required by room version
+    if (
+        event.type == EventTypes.PowerLevels
+        and room_version_obj.msc3667_int_only_power_levels
+    ):
+        for k, v in event.content.items():
+            if k in {
+                "users_default",
+                "events_default",
+                "state_default",
+                "ban",
+                "redact",
+                "kick",
+                "invite",
+            }:
+                if not isinstance(v, int):
+                    raise SynapseError(400, f"{v!r} must be an integer.")
+            if k in {"events", "notifications", "users"}:
+                if not isinstance(v, dict) or not all(
+                    isinstance(v, int) for v in v.values()
+                ):
+                    raise SynapseError(
+                        400,
+                        f"{v!r} must be a dict wherein all the values are integers.",
+                    )
+
     key = (event.type, event.state_key)
     current_state = auth_events.get(key)
 
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 98c203ada0..17f624b68f 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -24,9 +24,11 @@ from synapse.api.room_versions import (
     RoomVersion,
 )
 from synapse.crypto.event_signing import add_hashes_and_signatures
+from synapse.event_auth import auth_types_for_event
 from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
 from synapse.state import StateHandler
 from synapse.storage.databases.main import DataStore
+from synapse.storage.state import StateFilter
 from synapse.types import EventID, JsonDict
 from synapse.util import Clock
 from synapse.util.stringutils import random_string
@@ -120,8 +122,12 @@ class EventBuilder:
             The signed and hashed event.
         """
         if auth_event_ids is None:
-            state_ids = await self._state.get_current_state_ids(
-                self.room_id, prev_event_ids
+            state_ids = await self._state.compute_state_after_events(
+                self.room_id,
+                prev_event_ids,
+                state_filter=StateFilter.from_types(
+                    auth_types_for_event(self.room_version, self)
+                ),
             )
             auth_event_ids = self._event_auth_handler.compute_auth_events(
                 self, state_ids
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index b700cbbfa1..d3c8083e4a 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -11,11 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 import attr
 from frozendict import frozendict
-from typing_extensions import Literal
 
 from synapse.appservice import ApplicationService
 from synapse.events import EventBase
@@ -33,7 +32,7 @@ class EventContext:
     Holds information relevant to persisting an event
 
     Attributes:
-        rejected: A rejection reason if the event was rejected, else False
+        rejected: A rejection reason if the event was rejected, else None
 
         _state_group: The ID of the state group for this event. Note that state events
             are persisted with a state group which includes the new event, so this is
@@ -85,7 +84,7 @@ class EventContext:
     """
 
     _storage: "StorageControllers"
-    rejected: Union[Literal[False], str] = False
+    rejected: Optional[str] = None
     _state_group: Optional[int] = None
     state_group_before_event: Optional[int] = None
     _state_delta_due_to_event: Optional[StateMap[str]] = None
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index d2e06c754e..623a2c71ea 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -21,18 +21,22 @@ from typing import (
     Awaitable,
     Callable,
     Collection,
-    Dict,
     List,
     Optional,
     Tuple,
     Union,
 )
 
+# `Literal` appears with Python 3.8.
+from typing_extensions import Literal
+
+import synapse
 from synapse.api.errors import Codes
+from synapse.logging.opentracing import trace
 from synapse.rest.media.v1._base import FileInfo
 from synapse.rest.media.v1.media_storage import ReadableFileWrapper
 from synapse.spam_checker_api import RegistrationBehaviour
-from synapse.types import RoomAlias, UserProfile
+from synapse.types import JsonDict, RoomAlias, UserProfile
 from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
 from synapse.util.metrics import Measure
 
@@ -52,7 +56,7 @@ CHECK_EVENT_FOR_SPAM_CALLBACK = Callable[
             # disappear without warning depending on the results of ongoing
             # experiments.
             # Use this to return additional information as part of an error.
-            Tuple[Codes, Dict],
+            Tuple[Codes, JsonDict],
             # Deprecated
             bool,
         ]
@@ -62,12 +66,102 @@ SHOULD_DROP_FEDERATED_EVENT_CALLBACK = Callable[
     ["synapse.events.EventBase"],
     Awaitable[Union[bool, str]],
 ]
-USER_MAY_JOIN_ROOM_CALLBACK = Callable[[str, str, bool], Awaitable[bool]]
-USER_MAY_INVITE_CALLBACK = Callable[[str, str, str], Awaitable[bool]]
-USER_MAY_SEND_3PID_INVITE_CALLBACK = Callable[[str, str, str, str], Awaitable[bool]]
-USER_MAY_CREATE_ROOM_CALLBACK = Callable[[str], Awaitable[bool]]
-USER_MAY_CREATE_ROOM_ALIAS_CALLBACK = Callable[[str, RoomAlias], Awaitable[bool]]
-USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]]
+USER_MAY_JOIN_ROOM_CALLBACK = Callable[
+    [str, str, bool],
+    Awaitable[
+        Union[
+            Literal["NOT_SPAM"],
+            Codes,
+            # Highly experimental, not officially part of the spamchecker API, may
+            # disappear without warning depending on the results of ongoing
+            # experiments.
+            # Use this to return additional information as part of an error.
+            Tuple[Codes, JsonDict],
+            # Deprecated
+            bool,
+        ]
+    ],
+]
+USER_MAY_INVITE_CALLBACK = Callable[
+    [str, str, str],
+    Awaitable[
+        Union[
+            Literal["NOT_SPAM"],
+            Codes,
+            # Highly experimental, not officially part of the spamchecker API, may
+            # disappear without warning depending on the results of ongoing
+            # experiments.
+            # Use this to return additional information as part of an error.
+            Tuple[Codes, JsonDict],
+            # Deprecated
+            bool,
+        ]
+    ],
+]
+USER_MAY_SEND_3PID_INVITE_CALLBACK = Callable[
+    [str, str, str, str],
+    Awaitable[
+        Union[
+            Literal["NOT_SPAM"],
+            Codes,
+            # Highly experimental, not officially part of the spamchecker API, may
+            # disappear without warning depending on the results of ongoing
+            # experiments.
+            # Use this to return additional information as part of an error.
+            Tuple[Codes, JsonDict],
+            # Deprecated
+            bool,
+        ]
+    ],
+]
+USER_MAY_CREATE_ROOM_CALLBACK = Callable[
+    [str],
+    Awaitable[
+        Union[
+            Literal["NOT_SPAM"],
+            Codes,
+            # Highly experimental, not officially part of the spamchecker API, may
+            # disappear without warning depending on the results of ongoing
+            # experiments.
+            # Use this to return additional information as part of an error.
+            Tuple[Codes, JsonDict],
+            # Deprecated
+            bool,
+        ]
+    ],
+]
+USER_MAY_CREATE_ROOM_ALIAS_CALLBACK = Callable[
+    [str, RoomAlias],
+    Awaitable[
+        Union[
+            Literal["NOT_SPAM"],
+            Codes,
+            # Highly experimental, not officially part of the spamchecker API, may
+            # disappear without warning depending on the results of ongoing
+            # experiments.
+            # Use this to return additional information as part of an error.
+            Tuple[Codes, JsonDict],
+            # Deprecated
+            bool,
+        ]
+    ],
+]
+USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[
+    [str, str],
+    Awaitable[
+        Union[
+            Literal["NOT_SPAM"],
+            Codes,
+            # Highly experimental, not officially part of the spamchecker API, may
+            # disappear without warning depending on the results of ongoing
+            # experiments.
+            # Use this to return additional information as part of an error.
+            Tuple[Codes, JsonDict],
+            # Deprecated
+            bool,
+        ]
+    ],
+]
 CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[UserProfile], Awaitable[bool]]
 LEGACY_CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[
     [
@@ -88,7 +182,19 @@ CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[
 ]
 CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK = Callable[
     [ReadableFileWrapper, FileInfo],
-    Awaitable[bool],
+    Awaitable[
+        Union[
+            Literal["NOT_SPAM"],
+            Codes,
+            # Highly experimental, not officially part of the spamchecker API, may
+            # disappear without warning depending on the results of ongoing
+            # experiments.
+            # Use this to return additional information as part of an error.
+            Tuple[Codes, JsonDict],
+            # Deprecated
+            bool,
+        ]
+    ],
 ]
 
 
@@ -181,7 +287,7 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None:
 
 
 class SpamChecker:
-    NOT_SPAM = "NOT_SPAM"
+    NOT_SPAM: Literal["NOT_SPAM"] = "NOT_SPAM"
 
     def __init__(self, hs: "synapse.server.HomeServer") -> None:
         self.hs = hs
@@ -273,9 +379,10 @@ class SpamChecker:
         if check_media_file_for_spam is not None:
             self._check_media_file_for_spam_callbacks.append(check_media_file_for_spam)
 
+    @trace
     async def check_event_for_spam(
         self, event: "synapse.events.EventBase"
-    ) -> Union[Tuple[Codes, Dict], str]:
+    ) -> Union[Tuple[Codes, JsonDict], str]:
         """Checks if a given event is considered "spammy" by this server.
 
         If the server considers an event spammy, then it will be rejected if
@@ -306,7 +413,16 @@ class SpamChecker:
                 elif res is True:
                     # This spam-checker rejects the event with deprecated
                     # return value `True`
-                    return Codes.FORBIDDEN
+                    return synapse.api.errors.Codes.FORBIDDEN, {}
+                elif (
+                    isinstance(res, tuple)
+                    and len(res) == 2
+                    and isinstance(res[0], synapse.api.errors.Codes)
+                    and isinstance(res[1], dict)
+                ):
+                    return res
+                elif isinstance(res, synapse.api.errors.Codes):
+                    return res, {}
                 elif not isinstance(res, str):
                     # mypy complains that we can't reach this code because of the
                     # return type in CHECK_EVENT_FOR_SPAM_CALLBACK, but we don't know
@@ -352,7 +468,7 @@ class SpamChecker:
 
     async def user_may_join_room(
         self, user_id: str, room_id: str, is_invited: bool
-    ) -> bool:
+    ) -> Union[Tuple[Codes, JsonDict], Literal["NOT_SPAM"]]:
         """Checks if a given users is allowed to join a room.
         Not called when a user creates a room.
 
@@ -362,54 +478,84 @@ class SpamChecker:
             is_invited: Whether the user is invited into the room
 
         Returns:
-            Whether the user may join the room
+            NOT_SPAM if the operation is permitted, [Codes, Dict] otherwise.
         """
         for callback in self._user_may_join_room_callbacks:
             with Measure(
                 self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
             ):
-                may_join_room = await delay_cancellation(
-                    callback(user_id, room_id, is_invited)
-                )
-            if may_join_room is False:
-                return False
+                res = await delay_cancellation(callback(user_id, room_id, is_invited))
+                # Normalize return values to `Codes` or `"NOT_SPAM"`.
+                if res is True or res is self.NOT_SPAM:
+                    continue
+                elif res is False:
+                    return synapse.api.errors.Codes.FORBIDDEN, {}
+                elif isinstance(res, synapse.api.errors.Codes):
+                    return res, {}
+                elif (
+                    isinstance(res, tuple)
+                    and len(res) == 2
+                    and isinstance(res[0], synapse.api.errors.Codes)
+                    and isinstance(res[1], dict)
+                ):
+                    return res
+                else:
+                    logger.warning(
+                        "Module returned invalid value, rejecting join as spam"
+                    )
+                    return synapse.api.errors.Codes.FORBIDDEN, {}
 
-        return True
+        # No spam-checker has rejected the request, let it pass.
+        return self.NOT_SPAM
 
     async def user_may_invite(
         self, inviter_userid: str, invitee_userid: str, room_id: str
-    ) -> bool:
+    ) -> Union[Tuple[Codes, dict], Literal["NOT_SPAM"]]:
         """Checks if a given user may send an invite
 
-        If this method returns false, the invite will be rejected.
-
         Args:
             inviter_userid: The user ID of the sender of the invitation
             invitee_userid: The user ID targeted in the invitation
             room_id: The room ID
 
         Returns:
-            True if the user may send an invite, otherwise False
+            NOT_SPAM if the operation is permitted, Codes otherwise.
         """
         for callback in self._user_may_invite_callbacks:
             with Measure(
                 self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
             ):
-                may_invite = await delay_cancellation(
+                res = await delay_cancellation(
                     callback(inviter_userid, invitee_userid, room_id)
                 )
-            if may_invite is False:
-                return False
+                # Normalize return values to `Codes` or `"NOT_SPAM"`.
+                if res is True or res is self.NOT_SPAM:
+                    continue
+                elif res is False:
+                    return synapse.api.errors.Codes.FORBIDDEN, {}
+                elif isinstance(res, synapse.api.errors.Codes):
+                    return res, {}
+                elif (
+                    isinstance(res, tuple)
+                    and len(res) == 2
+                    and isinstance(res[0], synapse.api.errors.Codes)
+                    and isinstance(res[1], dict)
+                ):
+                    return res
+                else:
+                    logger.warning(
+                        "Module returned invalid value, rejecting invite as spam"
+                    )
+                    return synapse.api.errors.Codes.FORBIDDEN, {}
 
-        return True
+        # No spam-checker has rejected the request, let it pass.
+        return self.NOT_SPAM
 
     async def user_may_send_3pid_invite(
         self, inviter_userid: str, medium: str, address: str, room_id: str
-    ) -> bool:
+    ) -> Union[Tuple[Codes, dict], Literal["NOT_SPAM"]]:
         """Checks if a given user may invite a given threepid into the room
 
-        If this method returns false, the threepid invite will be rejected.
-
         Note that if the threepid is already associated with a Matrix user ID, Synapse
         will call user_may_invite with said user ID instead.
 
@@ -420,88 +566,141 @@ class SpamChecker:
             room_id: The room ID
 
         Returns:
-            True if the user may send the invite, otherwise False
+            NOT_SPAM if the operation is permitted, Codes otherwise.
         """
         for callback in self._user_may_send_3pid_invite_callbacks:
             with Measure(
                 self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
             ):
-                may_send_3pid_invite = await delay_cancellation(
+                res = await delay_cancellation(
                     callback(inviter_userid, medium, address, room_id)
                 )
-            if may_send_3pid_invite is False:
-                return False
+                # Normalize return values to `Codes` or `"NOT_SPAM"`.
+                if res is True or res is self.NOT_SPAM:
+                    continue
+                elif res is False:
+                    return synapse.api.errors.Codes.FORBIDDEN, {}
+                elif isinstance(res, synapse.api.errors.Codes):
+                    return res, {}
+                elif (
+                    isinstance(res, tuple)
+                    and len(res) == 2
+                    and isinstance(res[0], synapse.api.errors.Codes)
+                    and isinstance(res[1], dict)
+                ):
+                    return res
+                else:
+                    logger.warning(
+                        "Module returned invalid value, rejecting 3pid invite as spam"
+                    )
+                    return synapse.api.errors.Codes.FORBIDDEN, {}
 
-        return True
+        return self.NOT_SPAM
 
-    async def user_may_create_room(self, userid: str) -> bool:
+    async def user_may_create_room(
+        self, userid: str
+    ) -> Union[Tuple[Codes, dict], Literal["NOT_SPAM"]]:
         """Checks if a given user may create a room
 
-        If this method returns false, the creation request will be rejected.
-
         Args:
             userid: The ID of the user attempting to create a room
-
-        Returns:
-            True if the user may create a room, otherwise False
         """
         for callback in self._user_may_create_room_callbacks:
             with Measure(
                 self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
             ):
-                may_create_room = await delay_cancellation(callback(userid))
-            if may_create_room is False:
-                return False
+                res = await delay_cancellation(callback(userid))
+                if res is True or res is self.NOT_SPAM:
+                    continue
+                elif res is False:
+                    return synapse.api.errors.Codes.FORBIDDEN, {}
+                elif isinstance(res, synapse.api.errors.Codes):
+                    return res, {}
+                elif (
+                    isinstance(res, tuple)
+                    and len(res) == 2
+                    and isinstance(res[0], synapse.api.errors.Codes)
+                    and isinstance(res[1], dict)
+                ):
+                    return res
+                else:
+                    logger.warning(
+                        "Module returned invalid value, rejecting room creation as spam"
+                    )
+                    return synapse.api.errors.Codes.FORBIDDEN, {}
 
-        return True
+        return self.NOT_SPAM
 
     async def user_may_create_room_alias(
         self, userid: str, room_alias: RoomAlias
-    ) -> bool:
+    ) -> Union[Tuple[Codes, dict], Literal["NOT_SPAM"]]:
         """Checks if a given user may create a room alias
 
-        If this method returns false, the association request will be rejected.
-
         Args:
             userid: The ID of the user attempting to create a room alias
             room_alias: The alias to be created
 
-        Returns:
-            True if the user may create a room alias, otherwise False
         """
         for callback in self._user_may_create_room_alias_callbacks:
             with Measure(
                 self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
             ):
-                may_create_room_alias = await delay_cancellation(
-                    callback(userid, room_alias)
-                )
-            if may_create_room_alias is False:
-                return False
+                res = await delay_cancellation(callback(userid, room_alias))
+                if res is True or res is self.NOT_SPAM:
+                    continue
+                elif res is False:
+                    return synapse.api.errors.Codes.FORBIDDEN, {}
+                elif isinstance(res, synapse.api.errors.Codes):
+                    return res, {}
+                elif (
+                    isinstance(res, tuple)
+                    and len(res) == 2
+                    and isinstance(res[0], synapse.api.errors.Codes)
+                    and isinstance(res[1], dict)
+                ):
+                    return res
+                else:
+                    logger.warning(
+                        "Module returned invalid value, rejecting room create as spam"
+                    )
+                    return synapse.api.errors.Codes.FORBIDDEN, {}
 
-        return True
+        return self.NOT_SPAM
 
-    async def user_may_publish_room(self, userid: str, room_id: str) -> bool:
+    async def user_may_publish_room(
+        self, userid: str, room_id: str
+    ) -> Union[Tuple[Codes, dict], Literal["NOT_SPAM"]]:
         """Checks if a given user may publish a room to the directory
 
-        If this method returns false, the publish request will be rejected.
-
         Args:
             userid: The user ID attempting to publish the room
             room_id: The ID of the room that would be published
-
-        Returns:
-            True if the user may publish the room, otherwise False
         """
         for callback in self._user_may_publish_room_callbacks:
             with Measure(
                 self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
             ):
-                may_publish_room = await delay_cancellation(callback(userid, room_id))
-            if may_publish_room is False:
-                return False
+                res = await delay_cancellation(callback(userid, room_id))
+                if res is True or res is self.NOT_SPAM:
+                    continue
+                elif res is False:
+                    return synapse.api.errors.Codes.FORBIDDEN, {}
+                elif isinstance(res, synapse.api.errors.Codes):
+                    return res, {}
+                elif (
+                    isinstance(res, tuple)
+                    and len(res) == 2
+                    and isinstance(res[0], synapse.api.errors.Codes)
+                    and isinstance(res[1], dict)
+                ):
+                    return res
+                else:
+                    logger.warning(
+                        "Module returned invalid value, rejecting room publication as spam"
+                    )
+                    return synapse.api.errors.Codes.FORBIDDEN, {}
 
-        return True
+        return self.NOT_SPAM
 
     async def check_username_for_spam(self, user_profile: UserProfile) -> bool:
         """Checks if a user ID or display name are considered "spammy" by this server.
@@ -567,7 +766,7 @@ class SpamChecker:
 
     async def check_media_file_for_spam(
         self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
-    ) -> bool:
+    ) -> Union[Tuple[Codes, dict], Literal["NOT_SPAM"]]:
         """Checks if a piece of newly uploaded media should be blocked.
 
         This will be called for local uploads, downloads of remote media, each
@@ -580,31 +779,44 @@ class SpamChecker:
 
             async def check_media_file_for_spam(
                 self, file: ReadableFileWrapper, file_info: FileInfo
-            ) -> bool:
+            ) -> Union[Codes, Literal["NOT_SPAM"]]:
                 buffer = BytesIO()
                 await file.write_chunks_to(buffer.write)
 
                 if buffer.getvalue() == b"Hello World":
-                    return True
+                    return synapse.module_api.NOT_SPAM
 
-                return False
+                return Codes.FORBIDDEN
 
 
         Args:
             file: An object that allows reading the contents of the media.
             file_info: Metadata about the file.
-
-        Returns:
-            True if the media should be blocked or False if it should be
-            allowed.
         """
 
         for callback in self._check_media_file_for_spam_callbacks:
             with Measure(
                 self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
             ):
-                spam = await delay_cancellation(callback(file_wrapper, file_info))
-            if spam:
-                return True
+                res = await delay_cancellation(callback(file_wrapper, file_info))
+                # Normalize return values to `Codes` or `"NOT_SPAM"`.
+                if res is False or res is self.NOT_SPAM:
+                    continue
+                elif res is True:
+                    return synapse.api.errors.Codes.FORBIDDEN, {}
+                elif isinstance(res, synapse.api.errors.Codes):
+                    return res, {}
+                elif (
+                    isinstance(res, tuple)
+                    and len(res) == 2
+                    and isinstance(res[0], synapse.api.errors.Codes)
+                    and isinstance(res[1], dict)
+                ):
+                    return res
+                else:
+                    logger.warning(
+                        "Module returned invalid value, rejecting media file as spam"
+                    )
+                    return synapse.api.errors.Codes.FORBIDDEN, {}
 
-        return False
+        return self.NOT_SPAM
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 35f3f3690f..72ab696898 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -464,14 +464,7 @@ class ThirdPartyEventRules:
         Returns:
             A dict mapping (event type, state key) to state event.
         """
-        state_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
-        room_state_events = await self.store.get_events(state_ids.values())
-
-        state_events = {}
-        for key, event_id in state_ids.items():
-            state_events[key] = room_state_events[event_id]
-
-        return state_events
+        return await self._storage_controllers.state.get_current_state(room_id)
 
     async def on_profile_update(
         self, user_id: str, new_profile: ProfileInfo, by_admin: bool, deactivation: bool
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index ac91c5eb57..71853caad8 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -161,7 +161,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDic
     elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_BATCH:
         add_fields(EventContentFields.MSC2716_BATCH_ID)
     elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_MARKER:
-        add_fields(EventContentFields.MSC2716_MARKER_INSERTION)
+        add_fields(EventContentFields.MSC2716_INSERTION_EVENT_REFERENCE)
 
     allowed_fields = {k: v for k, v in event_dict.items() if k in allowed_keys}
 
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 29fa9b3880..27c8beba25 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -35,6 +35,10 @@ class EventValidator:
     def validate_new(self, event: EventBase, config: HomeServerConfig) -> None:
         """Validates the event has roughly the right format
 
+        Suitable for checking a locally-created event. It has stricter checks than
+        is appropriate for an event received over federation (for which, see
+        event_auth.validate_event_for_room_version)
+
         Args:
             event: The event to validate.
             config: The homeserver's configuration.
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 2522bf78fc..4269a98db2 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -23,6 +23,7 @@ from synapse.crypto.keyring import Keyring
 from synapse.events import EventBase, make_event_from_dict
 from synapse.events.utils import prune_event, validate_canonicaljson
 from synapse.http.servlet import assert_params_in_dict
+from synapse.logging.opentracing import log_kv, trace
 from synapse.types import JsonDict, get_domain_from_id
 
 if TYPE_CHECKING:
@@ -55,6 +56,7 @@ class FederationBase:
         self._clock = hs.get_clock()
         self._storage_controllers = hs.get_storage_controllers()
 
+    @trace
     async def _check_sigs_and_hash(
         self, room_version: RoomVersion, pdu: EventBase
     ) -> EventBase:
@@ -97,17 +99,36 @@ class FederationBase:
                     "Event %s seems to have been redacted; using our redacted copy",
                     pdu.event_id,
                 )
+                log_kv(
+                    {
+                        "message": "Event seems to have been redacted; using our redacted copy",
+                        "event_id": pdu.event_id,
+                    }
+                )
             else:
                 logger.warning(
                     "Event %s content has been tampered, redacting",
                     pdu.event_id,
                 )
+                log_kv(
+                    {
+                        "message": "Event content has been tampered, redacting",
+                        "event_id": pdu.event_id,
+                    }
+                )
             return redacted_event
 
         spam_check = await self.spam_checker.check_event_for_spam(pdu)
 
         if spam_check != self.spam_checker.NOT_SPAM:
             logger.warning("Event contains spam, soft-failing %s", pdu.event_id)
+            log_kv(
+                {
+                    "message": "Event contains spam, redacting (to save disk space) "
+                    "as well as soft-failing (to stop using the event in prev_events)",
+                    "event_id": pdu.event_id,
+                }
+            )
             # we redact (to save disk space) as well as soft-failing (to stop
             # using the event in prev_events).
             redacted_event = prune_event(pdu)
@@ -117,6 +138,7 @@ class FederationBase:
         return pdu
 
 
+@trace
 async def _check_sigs_on_pdu(
     keyring: Keyring, room_version: RoomVersion, pdu: EventBase
 ) -> None:
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index ad475a913b..7ee2974bb1 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -53,7 +53,7 @@ from synapse.api.room_versions import (
     RoomVersion,
     RoomVersions,
 )
-from synapse.events import EventBase, builder
+from synapse.events import EventBase, builder, make_event_from_dict
 from synapse.federation.federation_base import (
     FederationBase,
     InvalidEventSignatureError,
@@ -61,6 +61,7 @@ from synapse.federation.federation_base import (
 )
 from synapse.federation.transport.client import SendJoinResponse
 from synapse.http.types import QueryParams
+from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, tag_args, trace
 from synapse.types import JsonDict, UserID, get_domain_from_id
 from synapse.util.async_helpers import concurrently_execute
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -217,7 +218,7 @@ class FederationClient(FederationBase):
         )
 
     async def claim_client_keys(
-        self, destination: str, content: JsonDict, timeout: int
+        self, destination: str, content: JsonDict, timeout: Optional[int]
     ) -> JsonDict:
         """Claims one-time keys for a device hosted on a remote server.
 
@@ -233,6 +234,8 @@ class FederationClient(FederationBase):
             destination, content, timeout
         )
 
+    @trace
+    @tag_args
     async def backfill(
         self, dest: str, room_id: str, limit: int, extremities: Collection[str]
     ) -> Optional[List[EventBase]]:
@@ -299,7 +302,8 @@ class FederationClient(FederationBase):
                 moving to the next destination. None indicates no timeout.
 
         Returns:
-            The requested PDU, or None if we were unable to find it.
+            A copy of the requested PDU that is safe to modify, or None if we
+            were unable to find it.
 
         Raises:
             SynapseError, NotRetryingDestination, FederationDeniedError
@@ -309,7 +313,7 @@ class FederationClient(FederationBase):
         )
 
         logger.debug(
-            "retrieved event id %s from %s: %r",
+            "get_pdu_from_destination_raw: retrieved event id %s from %s: %r",
             event_id,
             destination,
             transaction_data,
@@ -334,6 +338,8 @@ class FederationClient(FederationBase):
 
         return None
 
+    @trace
+    @tag_args
     async def get_pdu(
         self,
         destinations: Iterable[str],
@@ -358,55 +364,95 @@ class FederationClient(FederationBase):
             The requested PDU, or None if we were unable to find it.
         """
 
-        # TODO: Rate limit the number of times we try and get the same event.
+        logger.debug(
+            "get_pdu: event_id=%s from destinations=%s", event_id, destinations
+        )
 
-        ev = self._get_pdu_cache.get(event_id)
-        if ev:
-            return ev
+        # TODO: Rate limit the number of times we try and get the same event.
 
-        pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
+        # We might need the same event multiple times in quick succession (before
+        # it gets persisted to the database), so we cache the results of the lookup.
+        # Note that this is separate to the regular get_event cache which caches
+        # events once they have been persisted.
+        event = self._get_pdu_cache.get(event_id)
+
+        # If we don't see the event in the cache, go try to fetch it from the
+        # provided remote federated destinations
+        if not event:
+            pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
+
+            for destination in destinations:
+                now = self._clock.time_msec()
+                last_attempt = pdu_attempts.get(destination, 0)
+                if last_attempt + PDU_RETRY_TIME_MS > now:
+                    logger.debug(
+                        "get_pdu: skipping destination=%s because we tried it recently last_attempt=%s and we only check every %s (now=%s)",
+                        destination,
+                        last_attempt,
+                        PDU_RETRY_TIME_MS,
+                        now,
+                    )
+                    continue
+
+                try:
+                    event = await self.get_pdu_from_destination_raw(
+                        destination=destination,
+                        event_id=event_id,
+                        room_version=room_version,
+                        timeout=timeout,
+                    )
 
-        signed_pdu = None
-        for destination in destinations:
-            now = self._clock.time_msec()
-            last_attempt = pdu_attempts.get(destination, 0)
-            if last_attempt + PDU_RETRY_TIME_MS > now:
-                continue
+                    pdu_attempts[destination] = now
 
-            try:
-                signed_pdu = await self.get_pdu_from_destination_raw(
-                    destination=destination,
-                    event_id=event_id,
-                    room_version=room_version,
-                    timeout=timeout,
-                )
+                    if event:
+                        # Prime the cache
+                        self._get_pdu_cache[event.event_id] = event
 
-                pdu_attempts[destination] = now
+                        # Now that we have an event, we can break out of this
+                        # loop and stop asking other destinations.
+                        break
 
-            except SynapseError as e:
-                logger.info(
-                    "Failed to get PDU %s from %s because %s", event_id, destination, e
-                )
-                continue
-            except NotRetryingDestination as e:
-                logger.info(str(e))
-                continue
-            except FederationDeniedError as e:
-                logger.info(str(e))
-                continue
-            except Exception as e:
-                pdu_attempts[destination] = now
+                except SynapseError as e:
+                    logger.info(
+                        "Failed to get PDU %s from %s because %s",
+                        event_id,
+                        destination,
+                        e,
+                    )
+                    continue
+                except NotRetryingDestination as e:
+                    logger.info(str(e))
+                    continue
+                except FederationDeniedError as e:
+                    logger.info(str(e))
+                    continue
+                except Exception as e:
+                    pdu_attempts[destination] = now
+
+                    logger.info(
+                        "Failed to get PDU %s from %s because %s",
+                        event_id,
+                        destination,
+                        e,
+                    )
+                    continue
 
-                logger.info(
-                    "Failed to get PDU %s from %s because %s", event_id, destination, e
-                )
-                continue
+        if not event:
+            return None
 
-        if signed_pdu:
-            self._get_pdu_cache[event_id] = signed_pdu
+        # `event` now refers to an object stored in `get_pdu_cache`. Our
+        # callers may need to modify the returned object (eg to set
+        # `event.internal_metadata.outlier = true`), so we return a copy
+        # rather than the original object.
+        event_copy = make_event_from_dict(
+            event.get_pdu_json(),
+            event.room_version,
+        )
 
-        return signed_pdu
+        return event_copy
 
+    @trace
+    @tag_args
     async def get_room_state_ids(
         self, destination: str, room_id: str, event_id: str
     ) -> Tuple[List[str], List[str]]:
@@ -426,6 +472,23 @@ class FederationClient(FederationBase):
         state_event_ids = result["pdu_ids"]
         auth_event_ids = result.get("auth_chain_ids", [])
 
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "state_event_ids",
+            str(state_event_ids),
+        )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "state_event_ids.length",
+            str(len(state_event_ids)),
+        )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "auth_event_ids",
+            str(auth_event_ids),
+        )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "auth_event_ids.length",
+            str(len(auth_event_ids)),
+        )
+
         if not isinstance(state_event_ids, list) or not isinstance(
             auth_event_ids, list
         ):
@@ -433,6 +496,8 @@ class FederationClient(FederationBase):
 
         return state_event_ids, auth_event_ids
 
+    @trace
+    @tag_args
     async def get_room_state(
         self,
         destination: str,
@@ -492,6 +557,7 @@ class FederationClient(FederationBase):
 
         return valid_state_events, valid_auth_events
 
+    @trace
     async def _check_sigs_and_hash_and_fetch(
         self,
         origin: str,
@@ -521,11 +587,15 @@ class FederationClient(FederationBase):
         Returns:
             A list of PDUs that have valid signatures and hashes.
         """
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "pdus.length",
+            str(len(pdus)),
+        )
 
         # We limit how many PDUs we check at once, as if we try to do hundreds
         # of thousands of PDUs at once we see large memory spikes.
 
-        valid_pdus = []
+        valid_pdus: List[EventBase] = []
 
         async def _execute(pdu: EventBase) -> None:
             valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
@@ -541,6 +611,8 @@ class FederationClient(FederationBase):
 
         return valid_pdus
 
+    @trace
+    @tag_args
     async def _check_sigs_and_hash_and_fetch_one(
         self,
         pdu: EventBase,
@@ -573,16 +645,27 @@ class FederationClient(FederationBase):
         except InvalidEventSignatureError as e:
             logger.warning(
                 "Signature on retrieved event %s was invalid (%s). "
-                "Checking local store/orgin server",
+                "Checking local store/origin server",
                 pdu.event_id,
                 e,
             )
+            log_kv(
+                {
+                    "message": "Signature on retrieved event was invalid. "
+                    "Checking local store/origin server",
+                    "event_id": pdu.event_id,
+                    "InvalidEventSignatureError": e,
+                }
+            )
 
         # Check local db.
         res = await self.store.get_event(
             pdu.event_id, allow_rejected=True, allow_none=True
         )
 
+        # If the PDU fails its signature check and we don't have it in our
+        # database, we then request it from sender's server (if that is not the
+        # same as `origin`).
         pdu_origin = get_domain_from_id(pdu.sender)
         if not res and pdu_origin != origin:
             try:
@@ -686,6 +769,12 @@ class FederationClient(FederationBase):
         if failover_errcodes is None:
             failover_errcodes = ()
 
+        if not destinations:
+            # Give a bit of a clearer message if no servers were specified at all.
+            raise SynapseError(
+                502, f"Failed to {description} via any server: No servers specified."
+            )
+
         for destination in destinations:
             if destination == self.server_name:
                 continue
@@ -735,7 +824,7 @@ class FederationClient(FederationBase):
                     "Failed to %s via %s", description, destination, exc_info=True
                 )
 
-        raise SynapseError(502, "Failed to %s via any server" % (description,))
+        raise SynapseError(502, f"Failed to {description} via any server")
 
     async def make_membership_event(
         self,
@@ -1642,10 +1731,6 @@ def _validate_hierarchy_event(d: JsonDict) -> None:
     if not isinstance(event_type, str):
         raise ValueError("Invalid event: 'event_type' must be a str")
 
-    room_id = d.get("room_id")
-    if not isinstance(room_id, str):
-        raise ValueError("Invalid event: 'room_id' must be a str")
-
     state_key = d.get("state_key")
     if not isinstance(state_key, str):
         raise ValueError("Invalid event: 'state_key' must be a str")
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 3e1518f1f6..3bf84cf625 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -61,12 +61,18 @@ from synapse.logging.context import (
     nested_logging_context,
     run_in_background,
 )
-from synapse.logging.opentracing import log_kv, start_active_span_from_edu, trace
+from synapse.logging.opentracing import (
+    log_kv,
+    start_active_span_from_edu,
+    tag_args,
+    trace,
+)
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.replication.http.federation import (
     ReplicationFederationSendEduRestServlet,
     ReplicationGetQueryRestServlet,
 )
+from synapse.storage.databases.main.events import PartialStateConflictError
 from synapse.storage.databases.main.lock import Lock
 from synapse.types import JsonDict, StateMap, get_domain_from_id
 from synapse.util import json_decoder, unwrapFirstError
@@ -117,6 +123,7 @@ class FederationServer(FederationBase):
         self._federation_event_handler = hs.get_federation_event_handler()
         self.state = hs.get_state_handler()
         self._event_auth_handler = hs.get_event_auth_handler()
+        self._room_member_handler = hs.get_room_member_handler()
 
         self._state_storage_controller = hs.get_storage_controllers().state
 
@@ -467,7 +474,7 @@ class FederationServer(FederationBase):
                     )
                     for pdu in pdus_by_room[room_id]:
                         event_id = pdu.event_id
-                        pdu_results[event_id] = e.error_dict()
+                        pdu_results[event_id] = e.error_dict(self.hs.config)
                     return
 
                 for pdu in pdus_by_room[room_id]:
@@ -545,6 +552,8 @@ class FederationServer(FederationBase):
 
         return 200, resp
 
+    @trace
+    @tag_args
     async def on_state_ids_request(
         self, origin: str, room_id: str, event_id: str
     ) -> Tuple[int, JsonDict]:
@@ -567,6 +576,8 @@ class FederationServer(FederationBase):
 
         return 200, resp
 
+    @trace
+    @tag_args
     async def _on_state_ids_request_compute(
         self, room_id: str, event_id: str
     ) -> JsonDict:
@@ -620,6 +631,15 @@ class FederationServer(FederationBase):
             )
             raise IncompatibleRoomVersionError(room_version=room_version)
 
+        # Refuse the request if that room has seen too many joins recently.
+        # This is in addition to the HS-level rate limiting applied by
+        # BaseFederationServlet.
+        # type-ignore: mypy doesn't seem able to deduce the type of the limiter(!?)
+        await self._room_member_handler._join_rate_per_room_limiter.ratelimit(  # type: ignore[has-type]
+            requester=None,
+            key=room_id,
+            update=False,
+        )
         pdu = await self.handler.on_make_join_request(origin, room_id, user_id)
         return {"event": pdu.get_templated_pdu_json(), "room_version": room_version}
 
@@ -654,6 +674,12 @@ class FederationServer(FederationBase):
         room_id: str,
         caller_supports_partial_state: bool = False,
     ) -> Dict[str, Any]:
+        await self._room_member_handler._join_rate_per_room_limiter.ratelimit(  # type: ignore[has-type]
+            requester=None,
+            key=room_id,
+            update=False,
+        )
+
         event, context = await self._on_send_membership_event(
             origin, content, Membership.JOIN, room_id
         )
@@ -737,6 +763,17 @@ class FederationServer(FederationBase):
             The partial knock event.
         """
         origin_host, _ = parse_server_name(origin)
+
+        if await self.store.is_partial_state_room(room_id):
+            # Before we do anything: check if the room is partial-stated.
+            # Note that at the time this check was added, `on_make_knock_request` would
+            # block due to https://github.com/matrix-org/synapse/issues/12997.
+            raise SynapseError(
+                404,
+                "Unable to handle /make_knock right now; this server is not fully joined.",
+                errcode=Codes.NOT_FOUND,
+            )
+
         await self.check_server_matches_acl(origin_host, room_id)
 
         room_version = await self.store.get_room_version(room_id)
@@ -826,8 +863,25 @@ class FederationServer(FederationBase):
                 Codes.BAD_JSON,
             )
 
+        # Note that get_room_version throws if the room does not exist here.
         room_version = await self.store.get_room_version(room_id)
 
+        if await self.store.is_partial_state_room(room_id):
+            # If our server is still only partially joined, we can't give a complete
+            # response to /send_join, /send_knock or /send_leave.
+            # This is because we will not be able to provide the server list (for partial
+            # joins) or the full state (for full joins).
+            # Return a 404 as we would if we weren't in the room at all.
+            logger.info(
+                f"Rejecting /send_{membership_type} to %s because it's a partial state room",
+                room_id,
+            )
+            raise SynapseError(
+                404,
+                f"Unable to handle /send_{membership_type} right now; this server is not fully joined.",
+                errcode=Codes.NOT_FOUND,
+            )
+
         if membership_type == Membership.KNOCK and not room_version.msc2403_knocking:
             raise SynapseError(
                 403,
@@ -882,9 +936,20 @@ class FederationServer(FederationBase):
             logger.warning("%s", errmsg)
             raise SynapseError(403, errmsg, Codes.FORBIDDEN)
 
-        return await self._federation_event_handler.on_send_membership_event(
-            origin, event
-        )
+        try:
+            return await self._federation_event_handler.on_send_membership_event(
+                origin, event
+            )
+        except PartialStateConflictError:
+            # The room was un-partial stated while we were persisting the event.
+            # Try once more, with full state this time.
+            logger.info(
+                "Room %s was un-partial stated during `on_send_membership_event`, trying again.",
+                room_id,
+            )
+            return await self._federation_event_handler.on_send_membership_event(
+                origin, event
+            )
 
     async def on_event_auth(
         self, origin: str, room_id: str, event_id: str
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 99a794c042..8bc60e3e3e 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -351,7 +351,11 @@ class FederationSender(AbstractFederationSender):
             self._is_processing = True
             while True:
                 last_token = await self.store.get_federation_out_pos("events")
-                next_token, events = await self.store.get_all_new_events_stream(
+                (
+                    next_token,
+                    events,
+                    event_to_received_ts,
+                ) = await self.store.get_all_new_events_stream(
                     last_token, self._last_poked_id, limit=100
                 )
 
@@ -437,6 +441,19 @@ class FederationSender(AbstractFederationSender):
                             destinations = await self._external_cache.get(
                                 "get_joined_hosts", str(sg)
                             )
+                            if destinations is None:
+                                # Add logging to help track down #13444
+                                logger.info(
+                                    "Unexpectedly did not have cached destinations for %s / %s",
+                                    sg,
+                                    event.event_id,
+                                )
+                        else:
+                            # Add logging to help track down #13444
+                            logger.info(
+                                "Unexpectedly did not have cached prev group for %s",
+                                event.event_id,
+                            )
 
                     if destinations is None:
                         try:
@@ -476,7 +493,7 @@ class FederationSender(AbstractFederationSender):
                         await self._send_pdu(event, sharded_destinations)
 
                         now = self.clock.time_msec()
-                        ts = await self.store.get_received_ts(event.event_id)
+                        ts = event_to_received_ts[event.event_id]
                         assert ts is not None
                         synapse.metrics.event_processing_lag_by_event.labels(
                             "federation_sender"
@@ -509,7 +526,7 @@ class FederationSender(AbstractFederationSender):
 
                 if events:
                     now = self.clock.time_msec()
-                    ts = await self.store.get_received_ts(events[-1].event_id)
+                    ts = event_to_received_ts[events[-1].event_id]
                     assert ts is not None
 
                     synapse.metrics.event_processing_lag.labels(
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 9e84bd677e..32074b8ca6 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -619,7 +619,7 @@ class TransportLayerClient:
         )
 
     async def claim_client_keys(
-        self, destination: str, query_content: JsonDict, timeout: int
+        self, destination: str, query_content: JsonDict, timeout: Optional[int]
     ) -> JsonDict:
         """Claim one-time keys for a list of devices hosted on a remote server.
 
diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py
index 84100a5a52..1db8009d6c 100644
--- a/synapse/federation/transport/server/_base.py
+++ b/synapse/federation/transport/server/_base.py
@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tupl
 
 from synapse.api.errors import Codes, FederationDeniedError, SynapseError
 from synapse.api.urls import FEDERATION_V1_PREFIX
-from synapse.http.server import HttpServer, ServletCallback, is_method_cancellable
+from synapse.http.server import HttpServer, ServletCallback
 from synapse.http.servlet import parse_json_object_from_request
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import run_in_background
@@ -34,6 +34,7 @@ from synapse.logging.opentracing import (
     whitelisted_homeserver,
 )
 from synapse.types import JsonDict
+from synapse.util.cancellation import is_function_cancellable
 from synapse.util.ratelimitutils import FederationRateLimiter
 from synapse.util.stringutils import parse_and_validate_server_name
 
@@ -309,7 +310,7 @@ class BaseFederationServlet:
                 raise
 
             # update the active opentracing span with the authenticated entity
-            set_tag("authenticated_entity", origin)
+            set_tag("authenticated_entity", str(origin))
 
             # if the origin is authenticated and whitelisted, use its span context
             # as the parent.
@@ -375,7 +376,7 @@ class BaseFederationServlet:
             if code is None:
                 continue
 
-            if is_method_cancellable(code):
+            if is_function_cancellable(code):
                 # The wrapper added by `self._wrap` will inherit the cancellable flag,
                 # but the wrapper itself does not support cancellation yet.
                 # Once resolved, the cancellation tests in
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 814553e098..203b62e015 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -104,14 +104,15 @@ class ApplicationServicesHandler:
         with Measure(self.clock, "notify_interested_services"):
             self.is_processing = True
             try:
-                limit = 100
                 upper_bound = -1
                 while upper_bound < self.current_max:
+                    last_token = await self.store.get_appservice_last_pos()
                     (
                         upper_bound,
                         events,
-                    ) = await self.store.get_new_events_for_appservice(
-                        self.current_max, limit
+                        event_to_received_ts,
+                    ) = await self.store.get_all_new_events_stream(
+                        last_token, self.current_max, limit=100, get_prev_content=True
                     )
 
                     events_by_room: Dict[str, List[EventBase]] = {}
@@ -150,7 +151,7 @@ class ApplicationServicesHandler:
                             )
 
                         now = self.clock.time_msec()
-                        ts = await self.store.get_received_ts(event.event_id)
+                        ts = event_to_received_ts[event.event_id]
                         assert ts is not None
 
                         synapse.metrics.event_processing_lag_by_event.labels(
@@ -187,7 +188,7 @@ class ApplicationServicesHandler:
 
                     if events:
                         now = self.clock.time_msec()
-                        ts = await self.store.get_received_ts(events[-1].event_id)
+                        ts = event_to_received_ts[events[-1].event_id]
                         assert ts is not None
 
                         synapse.metrics.event_processing_lag.labels(
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 6e15028b0a..0327fc57a4 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -37,9 +37,7 @@ from typing import (
 
 import attr
 import bcrypt
-import pymacaroons
 import unpaddedbase64
-from pymacaroons.exceptions import MacaroonVerificationFailedException
 
 from twisted.internet.defer import CancelledError
 from twisted.web.server import Request
@@ -69,7 +67,7 @@ from synapse.storage.roommember import ProfileInfo
 from synapse.types import JsonDict, Requester, UserID
 from synapse.util import stringutils as stringutils
 from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
-from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
+from synapse.util.macaroons import LoginTokenAttributes
 from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.stringutils import base62_encode
 from synapse.util.threepids import canonicalise_email
@@ -180,25 +178,13 @@ class SsoLoginExtraAttributes:
     extra_attributes: JsonDict
 
 
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class LoginTokenAttributes:
-    """Data we store in a short-term login token"""
-
-    user_id: str
-
-    auth_provider_id: str
-    """The SSO Identity Provider that the user authenticated with, to get this token."""
-
-    auth_provider_session_id: Optional[str]
-    """The session ID advertised by the SSO Identity Provider."""
-
-
 class AuthHandler:
     SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
 
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastores().main
         self.auth = hs.get_auth()
+        self.auth_blocking = hs.get_auth_blocking()
         self.clock = hs.get_clock()
         self.checkers: Dict[str, UserInteractiveAuthChecker] = {}
         for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
@@ -294,7 +280,7 @@ class AuthHandler:
         that it isn't stolen by re-authenticating them.
 
         Args:
-            requester: The user, as given by the access token
+            requester: The user making the request, according to the access token.
 
             request: The request sent by the client.
 
@@ -579,7 +565,7 @@ class AuthHandler:
             except LoginError as e:
                 # this step failed. Merge the error dict into the response
                 # so that the client can have another go.
-                errordict = e.error_dict()
+                errordict = e.error_dict(self.hs.config)
 
         creds = await self.store.get_completed_ui_auth_stages(session.session_id)
         for f in flows:
@@ -985,7 +971,7 @@ class AuthHandler:
             not is_appservice_ghost
             or self.hs.config.appservice.track_appservice_user_ips
         ):
-            await self.auth.check_auth_blocking(user_id)
+            await self.auth_blocking.check_auth_blocking(user_id)
 
         access_token = self.generate_access_token(target_user_id_obj)
         await self.store.add_access_token_to_user(
@@ -1439,7 +1425,7 @@ class AuthHandler:
         except Exception:
             raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN)
 
-        await self.auth.check_auth_blocking(res.user_id)
+        await self.auth_blocking.check_auth_blocking(res.user_id)
         return res
 
     async def delete_access_token(self, access_token: str) -> None:
@@ -1449,20 +1435,25 @@ class AuthHandler:
             access_token: access token to be deleted
 
         """
-        user_info = await self.auth.get_user_by_access_token(access_token)
+        token = await self.store.get_user_by_access_token(access_token)
+        if not token:
+            # At this point, the token should already have been fetched once by
+            # the caller, so this should not happen, unless of a race condition
+            # between two delete requests
+            raise SynapseError(HTTPStatus.UNAUTHORIZED, "Unrecognised access token")
         await self.store.delete_access_token(access_token)
 
         # see if any modules want to know about this
         await self.password_auth_provider.on_logged_out(
-            user_id=user_info.user_id,
-            device_id=user_info.device_id,
+            user_id=token.user_id,
+            device_id=token.device_id,
             access_token=access_token,
         )
 
         # delete pushers associated with this access token
-        if user_info.token_id is not None:
+        if token.token_id is not None:
             await self.hs.get_pusherpool().remove_pushers_by_access_token(
-                user_info.user_id, (user_info.token_id,)
+                token.user_id, (token.token_id,)
             )
 
     async def delete_access_tokens_for_user(
@@ -1830,98 +1821,6 @@ class AuthHandler:
         return urllib.parse.urlunparse(url_parts)
 
 
-@attr.s(slots=True, auto_attribs=True)
-class MacaroonGenerator:
-    hs: "HomeServer"
-
-    def generate_guest_access_token(self, user_id: str) -> str:
-        macaroon = self._generate_base_macaroon(user_id)
-        macaroon.add_first_party_caveat("type = access")
-        # Include a nonce, to make sure that each login gets a different
-        # access token.
-        macaroon.add_first_party_caveat(
-            "nonce = %s" % (stringutils.random_string_with_symbols(16),)
-        )
-        macaroon.add_first_party_caveat("guest = true")
-        return macaroon.serialize()
-
-    def generate_short_term_login_token(
-        self,
-        user_id: str,
-        auth_provider_id: str,
-        auth_provider_session_id: Optional[str] = None,
-        duration_in_ms: int = (2 * 60 * 1000),
-    ) -> str:
-        macaroon = self._generate_base_macaroon(user_id)
-        macaroon.add_first_party_caveat("type = login")
-        now = self.hs.get_clock().time_msec()
-        expiry = now + duration_in_ms
-        macaroon.add_first_party_caveat("time < %d" % (expiry,))
-        macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
-        if auth_provider_session_id is not None:
-            macaroon.add_first_party_caveat(
-                "auth_provider_session_id = %s" % (auth_provider_session_id,)
-            )
-        return macaroon.serialize()
-
-    def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
-        """Verify a short-term-login macaroon
-
-        Checks that the given token is a valid, unexpired short-term-login token
-        minted by this server.
-
-        Args:
-            token: the login token to verify
-
-        Returns:
-            the user_id that this token is valid for
-
-        Raises:
-            MacaroonVerificationFailedException if the verification failed
-        """
-        macaroon = pymacaroons.Macaroon.deserialize(token)
-        user_id = get_value_from_macaroon(macaroon, "user_id")
-        auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
-
-        auth_provider_session_id: Optional[str] = None
-        try:
-            auth_provider_session_id = get_value_from_macaroon(
-                macaroon, "auth_provider_session_id"
-            )
-        except MacaroonVerificationFailedException:
-            pass
-
-        v = pymacaroons.Verifier()
-        v.satisfy_exact("gen = 1")
-        v.satisfy_exact("type = login")
-        v.satisfy_general(lambda c: c.startswith("user_id = "))
-        v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
-        v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = "))
-        satisfy_expiry(v, self.hs.get_clock().time_msec)
-        v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
-
-        return LoginTokenAttributes(
-            user_id=user_id,
-            auth_provider_id=auth_provider_id,
-            auth_provider_session_id=auth_provider_session_id,
-        )
-
-    def generate_delete_pusher_token(self, user_id: str) -> str:
-        macaroon = self._generate_base_macaroon(user_id)
-        macaroon.add_first_party_caveat("type = delete_pusher")
-        return macaroon.serialize()
-
-    def _generate_base_macaroon(self, user_id: str) -> pymacaroons.Macaroon:
-        macaroon = pymacaroons.Macaroon(
-            location=self.hs.config.server.server_name,
-            identifier="key",
-            key=self.hs.config.key.macaroon_secret_key,
-        )
-        macaroon.add_first_party_caveat("gen = 1")
-        macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
-        return macaroon
-
-
 def load_legacy_password_auth_providers(hs: "HomeServer") -> None:
     module_api = hs.get_module_api()
     for module, config in hs.config.authproviders.password_providers:
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index b79c551703..9c2c3a0e68 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -74,6 +74,7 @@ class DeviceWorkerHandler:
         self._state_storage = hs.get_storage_controllers().state
         self._auth_handler = hs.get_auth_handler()
         self.server_name = hs.hostname
+        self._msc3852_enabled = hs.config.experimental.msc3852_enabled
 
     @trace
     async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
@@ -118,28 +119,33 @@ class DeviceWorkerHandler:
         ips = await self.store.get_last_client_ip_by_device(user_id, device_id)
         _update_device_from_client_ips(device, ips)
 
-        set_tag("device", device)
-        set_tag("ips", ips)
+        set_tag("device", str(device))
+        set_tag("ips", str(ips))
 
         return device
 
-    @trace
-    @measure_func("device.get_user_ids_changed")
-    async def get_user_ids_changed(
-        self, user_id: str, from_token: StreamToken
-    ) -> JsonDict:
-        """Get list of users that have had the devices updated, or have newly
-        joined a room, that `user_id` may be interested in.
+    async def get_device_changes_in_shared_rooms(
+        self, user_id: str, room_ids: Collection[str], from_token: StreamToken
+    ) -> Collection[str]:
+        """Get the set of users whose devices have changed who share a room with
+        the given user.
         """
+        changed_users = await self.store.get_device_list_changes_in_rooms(
+            room_ids, from_token.device_list_key
+        )
 
-        set_tag("user_id", user_id)
-        set_tag("from_token", from_token)
-        now_room_key = self.store.get_room_max_token()
+        if changed_users is not None:
+            # We also check if the given user has changed their device. If
+            # they're in no rooms then the above query won't include them.
+            changed = await self.store.get_users_whose_devices_changed(
+                from_token.device_list_key, [user_id]
+            )
+            changed_users.update(changed)
+            return changed_users
 
-        room_ids = await self.store.get_rooms_for_user(user_id)
+        # If the DB returned None then the `from_token` is too old, so we fall
+        # back on looking for device updates for all users.
 
-        # First we check if any devices have changed for users that we share
-        # rooms with.
         users_who_share_room = await self.store.get_users_who_share_room_with_user(
             user_id
         )
@@ -153,6 +159,27 @@ class DeviceWorkerHandler:
             from_token.device_list_key, tracked_users
         )
 
+        return changed
+
+    @trace
+    @measure_func("device.get_user_ids_changed")
+    async def get_user_ids_changed(
+        self, user_id: str, from_token: StreamToken
+    ) -> JsonDict:
+        """Get list of users that have had the devices updated, or have newly
+        joined a room, that `user_id` may be interested in.
+        """
+
+        set_tag("user_id", user_id)
+        set_tag("from_token", str(from_token))
+        now_room_key = self.store.get_room_max_token()
+
+        room_ids = await self.store.get_rooms_for_user(user_id)
+
+        changed = await self.get_device_changes_in_shared_rooms(
+            user_id, room_ids, from_token
+        )
+
         # Then work out if any users have since joined
         rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
 
@@ -237,10 +264,19 @@ class DeviceWorkerHandler:
                         break
 
         if possibly_changed or possibly_left:
-            # Take the intersection of the users whose devices may have changed
-            # and those that actually still share a room with the user
-            possibly_joined = possibly_changed & users_who_share_room
-            possibly_left = (possibly_changed | possibly_left) - users_who_share_room
+            possibly_joined = possibly_changed
+            possibly_left = possibly_changed | possibly_left
+
+            # Double check if we still share rooms with the given user.
+            users_rooms = await self.store.get_rooms_for_users_with_stream_ordering(
+                possibly_left
+            )
+            for changed_user_id, entries in users_rooms.items():
+                if any(e.room_id in room_ids for e in entries):
+                    possibly_left.discard(changed_user_id)
+                else:
+                    possibly_joined.discard(changed_user_id)
+
         else:
             possibly_joined = set()
             possibly_left = set()
@@ -274,6 +310,7 @@ class DeviceHandler(DeviceWorkerHandler):
         super().__init__(hs)
 
         self.federation_sender = hs.get_federation_sender()
+        self._storage_controllers = hs.get_storage_controllers()
 
         self.device_list_updater = DeviceListUpdater(hs, self)
 
@@ -658,8 +695,11 @@ class DeviceHandler(DeviceWorkerHandler):
 
                     # Ignore any users that aren't ours
                     if self.hs.is_mine_id(user_id):
-                        joined_user_ids = await self.store.get_users_in_room(room_id)
-                        hosts = {get_domain_from_id(u) for u in joined_user_ids}
+                        hosts = set(
+                            await self._storage_controllers.state.get_current_hosts_in_room(
+                                room_id
+                            )
+                        )
                         hosts.discard(self.server_name)
 
                     # Check if we've already sent this update to some hosts
@@ -712,7 +752,13 @@ def _update_device_from_client_ips(
     device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
 ) -> None:
     ip = client_ips.get((device["user_id"], device["device_id"]), {})
-    device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
+    device.update(
+        {
+            "last_seen_user_agent": ip.get("user_agent"),
+            "last_seen_ts": ip.get("last_seen"),
+            "last_seen_ip": ip.get("ip"),
+        }
+    )
 
 
 class DeviceListUpdater:
@@ -760,7 +806,7 @@ class DeviceListUpdater:
         """
 
         set_tag("origin", origin)
-        set_tag("edu_content", edu_content)
+        set_tag("edu_content", str(edu_content))
         user_id = edu_content.pop("user_id")
         device_id = edu_content.pop("device_id")
         stream_id = str(edu_content.pop("stream_id"))  # They may come as ints
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 1459a046de..7127d5aefc 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -28,8 +28,9 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.appservice import ApplicationService
+from synapse.module_api import NOT_SPAM
 from synapse.storage.databases.main.directory import RoomAliasMapping
-from synapse.types import JsonDict, Requester, RoomAlias, UserID, get_domain_from_id
+from synapse.types import JsonDict, Requester, RoomAlias
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -82,8 +83,9 @@ class DirectoryHandler:
         # TODO(erikj): Add transactions.
         # TODO(erikj): Check if there is a current association.
         if not servers:
-            users = await self.store.get_users_in_room(room_id)
-            servers = {get_domain_from_id(u) for u in users}
+            servers = await self._storage_controllers.state.get_current_hosts_in_room(
+                room_id
+            )
 
         if not servers:
             raise SynapseError(400, "Failed to get server list")
@@ -132,7 +134,7 @@ class DirectoryHandler:
         else:
             # Server admins are not subject to the same constraints as normal
             # users when creating an alias (e.g. being in the room).
-            is_admin = await self.auth.is_server_admin(requester.user)
+            is_admin = await self.auth.is_server_admin(requester)
 
             if (self.require_membership and check_membership) and not is_admin:
                 rooms_for_user = await self.store.get_rooms_for_user(user_id)
@@ -141,10 +143,16 @@ class DirectoryHandler:
                         403, "You must be in the room to create an alias for it"
                     )
 
-            if not await self.spam_checker.user_may_create_room_alias(
+            spam_check = await self.spam_checker.user_may_create_room_alias(
                 user_id, room_alias
-            ):
-                raise AuthError(403, "This user is not permitted to create this alias")
+            )
+            if spam_check != self.spam_checker.NOT_SPAM:
+                raise AuthError(
+                    403,
+                    "This user is not permitted to create this alias",
+                    errcode=spam_check[0],
+                    additional_fields=spam_check[1],
+                )
 
             if not self.config.roomdirectory.is_alias_creation_allowed(
                 user_id, room_id, room_alias_str
@@ -190,7 +198,7 @@ class DirectoryHandler:
         user_id = requester.user.to_string()
 
         try:
-            can_delete = await self._user_can_delete_alias(room_alias, user_id)
+            can_delete = await self._user_can_delete_alias(room_alias, requester)
         except StoreError as e:
             if e.code == 404:
                 raise NotFoundError("Unknown room alias")
@@ -280,8 +288,9 @@ class DirectoryHandler:
                 Codes.NOT_FOUND,
             )
 
-        users = await self.store.get_users_in_room(room_id)
-        extra_servers = {get_domain_from_id(u) for u in users}
+        extra_servers = await self._storage_controllers.state.get_current_hosts_in_room(
+            room_id
+        )
         servers_set = set(extra_servers) | set(servers)
 
         # If this server is in the list of servers, return it first.
@@ -393,7 +402,9 @@ class DirectoryHandler:
         # either no interested services, or no service with an exclusive lock
         return True
 
-    async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str) -> bool:
+    async def _user_can_delete_alias(
+        self, alias: RoomAlias, requester: Requester
+    ) -> bool:
         """Determine whether a user can delete an alias.
 
         One of the following must be true:
@@ -406,7 +417,7 @@ class DirectoryHandler:
         """
         creator = await self.store.get_room_alias_creator(alias.to_string())
 
-        if creator == user_id:
+        if creator == requester.user.to_string():
             return True
 
         # Resolve the alias to the corresponding room.
@@ -415,9 +426,7 @@ class DirectoryHandler:
         if not room_id:
             return False
 
-        return await self.auth.check_can_change_room_list(
-            room_id, UserID.from_string(user_id)
-        )
+        return await self.auth.check_can_change_room_list(room_id, requester)
 
     async def edit_published_room_list(
         self, requester: Requester, room_id: str, visibility: str
@@ -430,9 +439,13 @@ class DirectoryHandler:
         """
         user_id = requester.user.to_string()
 
-        if not await self.spam_checker.user_may_publish_room(user_id, room_id):
+        spam_check = await self.spam_checker.user_may_publish_room(user_id, room_id)
+        if spam_check != NOT_SPAM:
             raise AuthError(
-                403, "This user is not permitted to publish rooms to the room list"
+                403,
+                "This user is not permitted to publish rooms to the room list",
+                errcode=spam_check[0],
+                additional_fields=spam_check[1],
             )
 
         if requester.is_guest:
@@ -452,7 +465,7 @@ class DirectoryHandler:
             raise SynapseError(400, "Unknown room")
 
         can_change_room_list = await self.auth.check_can_change_room_list(
-            room_id, requester.user
+            room_id, requester
         )
         if not can_change_room_list:
             raise AuthError(
@@ -517,10 +530,8 @@ class DirectoryHandler:
         Get a list of the aliases that currently point to this room on this server
         """
         # allow access to server admins and current members of the room
-        is_admin = await self.auth.is_server_admin(requester.user)
+        is_admin = await self.auth.is_server_admin(requester)
         if not is_admin:
-            await self.auth.check_user_in_room_or_world_readable(
-                room_id, requester.user.to_string()
-            )
+            await self.auth.check_user_in_room_or_world_readable(room_id, requester)
 
         return await self.store.get_aliases_for_room(room_id)
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 52bb5c9c55..c938339ddd 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple
 
 import attr
 from canonicaljson import encode_canonical_json
@@ -92,7 +92,11 @@ class E2eKeysHandler:
 
     @trace
     async def query_devices(
-        self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str
+        self,
+        query_body: JsonDict,
+        timeout: int,
+        from_user_id: str,
+        from_device_id: Optional[str],
     ) -> JsonDict:
         """Handle a device key query from a client
 
@@ -120,9 +124,7 @@ class E2eKeysHandler:
                 the number of in-flight queries at a time.
         """
         async with self._query_devices_linearizer.queue((from_user_id, from_device_id)):
-            device_keys_query: Dict[str, Iterable[str]] = query_body.get(
-                "device_keys", {}
-            )
+            device_keys_query: Dict[str, List[str]] = query_body.get("device_keys", {})
 
             # separate users by domain.
             # make a map from domain to user_id to device_ids
@@ -136,8 +138,8 @@ class E2eKeysHandler:
                 else:
                     remote_queries[user_id] = device_ids
 
-            set_tag("local_key_query", local_query)
-            set_tag("remote_key_query", remote_queries)
+            set_tag("local_key_query", str(local_query))
+            set_tag("remote_key_query", str(remote_queries))
 
             # First get local devices.
             # A map of destination -> failure response.
@@ -341,7 +343,7 @@ class E2eKeysHandler:
             failure = _exception_to_failure(e)
             failures[destination] = failure
             set_tag("error", True)
-            set_tag("reason", failure)
+            set_tag("reason", str(failure))
 
         return
 
@@ -392,7 +394,7 @@ class E2eKeysHandler:
 
     @trace
     async def query_local_devices(
-        self, query: Dict[str, Optional[List[str]]]
+        self, query: Mapping[str, Optional[List[str]]]
     ) -> Dict[str, Dict[str, dict]]:
         """Get E2E device keys for local users
 
@@ -403,7 +405,7 @@ class E2eKeysHandler:
         Returns:
             A map from user_id -> device_id -> device details
         """
-        set_tag("local_query", query)
+        set_tag("local_query", str(query))
         local_query: List[Tuple[str, Optional[str]]] = []
 
         result_dict: Dict[str, Dict[str, dict]] = {}
@@ -461,7 +463,7 @@ class E2eKeysHandler:
 
     @trace
     async def claim_one_time_keys(
-        self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
+        self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
     ) -> JsonDict:
         local_query: List[Tuple[str, str, str]] = []
         remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
@@ -475,8 +477,8 @@ class E2eKeysHandler:
                 domain = get_domain_from_id(user_id)
                 remote_queries.setdefault(domain, {})[user_id] = one_time_keys
 
-        set_tag("local_key_query", local_query)
-        set_tag("remote_key_query", remote_queries)
+        set_tag("local_key_query", str(local_query))
+        set_tag("remote_key_query", str(remote_queries))
 
         results = await self.store.claim_e2e_one_time_keys(local_query)
 
@@ -506,7 +508,7 @@ class E2eKeysHandler:
                 failure = _exception_to_failure(e)
                 failures[destination] = failure
                 set_tag("error", True)
-                set_tag("reason", failure)
+                set_tag("reason", str(failure))
 
         await make_deferred_yieldable(
             defer.gatherResults(
@@ -609,7 +611,7 @@ class E2eKeysHandler:
 
         result = await self.store.count_e2e_one_time_keys(user_id, device_id)
 
-        set_tag("one_time_key_counts", result)
+        set_tag("one_time_key_counts", str(result))
         return {"one_time_key_counts": result}
 
     async def _upload_one_time_keys_for_user(
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 446f509bdc..28dc08c22a 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Dict, Optional
+from typing import TYPE_CHECKING, Dict, Optional, cast
 
 from typing_extensions import Literal
 
@@ -97,7 +97,7 @@ class E2eRoomKeysHandler:
                 user_id, version, room_id, session_id
             )
 
-            log_kv(results)
+            log_kv(cast(JsonDict, results))
             return results
 
     @trace
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index 6bed464351..c3ddc5d182 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -23,7 +23,10 @@ from synapse.api.constants import (
 )
 from synapse.api.errors import AuthError, Codes, SynapseError
 from synapse.api.room_versions import RoomVersion
-from synapse.event_auth import check_auth_rules_for_event
+from synapse.event_auth import (
+    check_state_dependent_auth_rules,
+    check_state_independent_auth_rules,
+)
 from synapse.events import EventBase
 from synapse.events.builder import EventBuilder
 from synapse.events.snapshot import EventContext
@@ -48,14 +51,14 @@ class EventAuthHandler:
 
     async def check_auth_rules_from_context(
         self,
-        room_version_obj: RoomVersion,
         event: EventBase,
         context: EventContext,
     ) -> None:
         """Check an event passes the auth rules at its own auth events"""
+        await check_state_independent_auth_rules(self._store, event)
         auth_event_ids = event.auth_event_ids()
         auth_events_by_id = await self._store.get_events(auth_event_ids)
-        check_auth_rules_for_event(room_version_obj, event, auth_events_by_id.values())
+        check_state_dependent_auth_rules(event, auth_events_by_id.values())
 
     def compute_auth_events(
         self,
@@ -126,12 +129,9 @@ class EventAuthHandler:
         else:
             users = {}
 
-        # Find the user with the highest power level.
-        users_in_room = await self._store.get_users_in_room(room_id)
-        # Only interested in local users.
-        local_users_in_room = [
-            u for u in users_in_room if get_domain_from_id(u) == self._server_name
-        ]
+        # Find the user with the highest power level (only interested in local
+        # users).
+        local_users_in_room = await self._store.get_local_users_in_room(room_id)
         chosen_user = max(
             local_users_in_room,
             key=lambda user: users.get(user, users_default_level),
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index ac13340d3a..949b69cb41 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -151,7 +151,7 @@ class EventHandler:
         """Retrieve a single specified event.
 
         Args:
-            user: The user requesting the event
+            user: The local user requesting the event
             room_id: The expected room id. We'll return None if the
                 event's room does not match.
             event_id: The event ID to obtain.
@@ -173,8 +173,11 @@ class EventHandler:
         if not event:
             return None
 
-        users = await self.store.get_users_in_room(event.room_id)
-        is_peeking = user.to_string() not in users
+        is_user_in_room = await self.store.check_local_user_in_room(
+            user_id=user.to_string(), room_id=event.room_id
+        )
+        # The user is peeking if they aren't in the room already
+        is_peeking = not is_user_in_room
 
         filtered = await filter_events_for_client(
             self._storage_controllers, user.to_string(), [event], is_peeking=is_peeking
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 6a143440d3..dd4b9f66d1 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -32,6 +32,7 @@ from typing import (
 )
 
 import attr
+from prometheus_client import Histogram
 from signedjson.key import decode_verify_key_bytes
 from signedjson.sign import verify_signed_json
 from unpaddedbase64 import decode_base64
@@ -45,6 +46,7 @@ from synapse.api.errors import (
     FederationDeniedError,
     FederationError,
     HttpResponseException,
+    LimitExceededError,
     NotFoundError,
     RequestSendFailed,
     SynapseError,
@@ -58,14 +60,17 @@ from synapse.events.validator import EventValidator
 from synapse.federation.federation_client import InvalidResponseError
 from synapse.http.servlet import assert_params_in_dict
 from synapse.logging.context import nested_logging_context
+from synapse.logging.opentracing import SynapseTags, set_tag, tag_args, trace
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.module_api import NOT_SPAM
 from synapse.replication.http.federation import (
     ReplicationCleanRoomRestServlet,
     ReplicationStoreRoomOnOutlierMembershipRestServlet,
 )
+from synapse.storage.databases.main.events import PartialStateConflictError
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.state import StateFilter
-from synapse.types import JsonDict, StateMap, get_domain_from_id
+from synapse.types import JsonDict, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
 from synapse.util.retryutils import NotRetryingDestination
 from synapse.visibility import filter_events_for_server
@@ -75,36 +80,28 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
-
-def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
-    """Get joined domains from state
-
-    Args:
-        state: State map from type/state key to event.
-
-    Returns:
-        Returns a list of servers with the lowest depth of their joins.
-            Sorted by lowest depth first.
-    """
-    joined_users = [
-        (state_key, int(event.depth))
-        for (e_type, state_key), event in state.items()
-        if e_type == EventTypes.Member and event.membership == Membership.JOIN
-    ]
-
-    joined_domains: Dict[str, int] = {}
-    for u, d in joined_users:
-        try:
-            dom = get_domain_from_id(u)
-            old_d = joined_domains.get(dom)
-            if old_d:
-                joined_domains[dom] = min(d, old_d)
-            else:
-                joined_domains[dom] = d
-        except Exception:
-            pass
-
-    return sorted(joined_domains.items(), key=lambda d: d[1])
+# Added to debug performance and track progress on optimizations
+backfill_processing_before_timer = Histogram(
+    "synapse_federation_backfill_processing_before_time_seconds",
+    "sec",
+    [],
+    buckets=(
+        0.1,
+        0.5,
+        1.0,
+        2.5,
+        5.0,
+        7.5,
+        10.0,
+        15.0,
+        20.0,
+        30.0,
+        40.0,
+        60.0,
+        80.0,
+        "+Inf",
+    ),
+)
 
 
 class _BackfillPointType(Enum):
@@ -134,6 +131,7 @@ class FederationHandler:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
 
+        self.clock = hs.get_clock()
         self.store = hs.get_datastores().main
         self._storage_controllers = hs.get_storage_controllers()
         self._state_storage_controller = self._storage_controllers.state
@@ -177,6 +175,7 @@ class FederationHandler:
                 "resume_sync_partial_state_room", self._resume_sync_partial_state_room
             )
 
+    @trace
     async def maybe_backfill(
         self, room_id: str, current_depth: int, limit: int
     ) -> bool:
@@ -192,12 +191,39 @@ class FederationHandler:
                 return. This is used as part of the heuristic to decide if we
                 should back paginate.
         """
+        # Starting the processing time here so we can include the room backfill
+        # linearizer lock queue in the timing
+        processing_start_time = self.clock.time_msec()
+
         async with self._room_backfill.queue(room_id):
-            return await self._maybe_backfill_inner(room_id, current_depth, limit)
+            return await self._maybe_backfill_inner(
+                room_id,
+                current_depth,
+                limit,
+                processing_start_time=processing_start_time,
+            )
 
     async def _maybe_backfill_inner(
-        self, room_id: str, current_depth: int, limit: int
+        self,
+        room_id: str,
+        current_depth: int,
+        limit: int,
+        *,
+        processing_start_time: int,
     ) -> bool:
+        """
+        Checks whether the `current_depth` is at or approaching any backfill
+        points in the room and if so, will backfill. We only care about
+        checking backfill points that happened before the `current_depth`
+        (meaning less than or equal to the `current_depth`).
+
+        Args:
+            room_id: The room to backfill in.
+            current_depth: The depth to check at for any upcoming backfill points.
+            limit: The max number of events to request from the remote federated server.
+            processing_start_time: The time when `maybe_backfill` started
+                processing. Only used for timing.
+        """
         backwards_extremities = [
             _BackfillPoint(event_id, depth, _BackfillPointType.BACKWARDS_EXTREMITY)
             for event_id, depth in await self.store.get_oldest_event_ids_with_depth_in_room(
@@ -365,23 +391,29 @@ class FederationHandler:
         logger.debug(
             "_maybe_backfill_inner: extremities_to_request %s", extremities_to_request
         )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "extremities_to_request",
+            str(extremities_to_request),
+        )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "extremities_to_request.length",
+            str(len(extremities_to_request)),
+        )
 
         # Now we need to decide which hosts to hit first.
-
-        # First we try hosts that are already in the room
+        # First we try hosts that are already in the room.
         # TODO: HEURISTIC ALERT.
+        likely_domains = (
+            await self._storage_controllers.state.get_current_hosts_in_room(room_id)
+        )
 
-        curr_state = await self._storage_controllers.state.get_current_state(room_id)
-
-        curr_domains = get_domains_from_state(curr_state)
-
-        likely_domains = [
-            domain for domain, depth in curr_domains if domain != self.server_name
-        ]
-
-        async def try_backfill(domains: List[str]) -> bool:
+        async def try_backfill(domains: Collection[str]) -> bool:
             # TODO: Should we try multiple of these at a time?
             for dom in domains:
+                # We don't want to ask our own server for information we don't have
+                if dom == self.server_name:
+                    continue
+
                 try:
                     await self._federation_event_handler.backfill(
                         dom, room_id, limit=100, extremities=extremities_to_request
@@ -420,6 +452,11 @@ class FederationHandler:
 
             return False
 
+        processing_end_time = self.clock.time_msec()
+        backfill_processing_before_timer.observe(
+            (processing_end_time - processing_start_time) / 1000
+        )
+
         success = await try_backfill(likely_domains)
         if success:
             return True
@@ -543,30 +580,49 @@ class FederationHandler:
             )
 
             if ret.partial_state:
-                # TODO(faster_joins): roll this back if we don't manage to start the
-                #   background resync (eg process_remote_join fails)
+                # Mark the room as having partial state.
+                # The background process is responsible for unmarking this flag,
+                # even if the join fails.
                 await self.store.store_partial_state_room(room_id, ret.servers_in_room)
 
-            max_stream_id = await self._federation_event_handler.process_remote_join(
-                origin,
-                room_id,
-                auth_chain,
-                state,
-                event,
-                room_version_obj,
-                partial_state=ret.partial_state,
-            )
-
-            if ret.partial_state:
-                # Kick off the process of asynchronously fetching the state for this
-                # room.
-                run_as_background_process(
-                    desc="sync_partial_state_room",
-                    func=self._sync_partial_state_room,
-                    initial_destination=origin,
-                    other_destinations=ret.servers_in_room,
-                    room_id=room_id,
+            try:
+                max_stream_id = (
+                    await self._federation_event_handler.process_remote_join(
+                        origin,
+                        room_id,
+                        auth_chain,
+                        state,
+                        event,
+                        room_version_obj,
+                        partial_state=ret.partial_state,
+                    )
                 )
+            except PartialStateConflictError as e:
+                # The homeserver was already in the room and it is no longer partial
+                # stated. We ought to be doing a local join instead. Turn the error into
+                # a 429, as a hint to the client to try again.
+                # TODO(faster_joins): `_should_perform_remote_join` suggests that we may
+                #   do a remote join for restricted rooms even if we have full state.
+                logger.error(
+                    "Room %s was un-partial stated while processing remote join.",
+                    room_id,
+                )
+                raise LimitExceededError(msg=e.msg, errcode=e.errcode, retry_after_ms=0)
+            finally:
+                # Always kick off the background process that asynchronously fetches
+                # state for the room.
+                # If the join failed, the background process is responsible for
+                # cleaning up — including unmarking the room as a partial state room.
+                if ret.partial_state:
+                    # Kick off the process of asynchronously fetching the state for this
+                    # room.
+                    run_as_background_process(
+                        desc="sync_partial_state_room",
+                        func=self._sync_partial_state_room,
+                        initial_destination=origin,
+                        other_destinations=ret.servers_in_room,
+                        room_id=room_id,
+                    )
 
             # We wait here until this instance has seen the events come down
             # replication (if we're using replication) as the below uses caches.
@@ -730,6 +786,23 @@ class FederationHandler:
         # (and return a 404 otherwise)
         room_version = await self.store.get_room_version(room_id)
 
+        if await self.store.is_partial_state_room(room_id):
+            # If our server is still only partially joined, we can't give a complete
+            # response to /make_join, so return a 404 as we would if we weren't in the
+            # room at all.
+            # The main reason we can't respond properly is that we need to know about
+            # the auth events for the join event that we would return.
+            # We also should not bother entertaining the /make_join since we cannot
+            # handle the /send_join.
+            logger.info(
+                "Rejecting /make_join to %s because it's a partial state room", room_id
+            )
+            raise SynapseError(
+                404,
+                "Unable to handle /make_join right now; this server is not fully joined.",
+                errcode=Codes.NOT_FOUND,
+            )
+
         # now check that we are *still* in the room
         is_in_room = await self._event_auth_handler.check_host_in_room(
             room_id, self.server_name
@@ -799,9 +872,7 @@ class FederationHandler:
 
         # The remote hasn't signed it yet, obviously. We'll do the full checks
         # when we get the event back in `on_send_join_request`
-        await self._event_auth_handler.check_auth_rules_from_context(
-            room_version, event, context
-        )
+        await self._event_auth_handler.check_auth_rules_from_context(event, context)
         return event
 
     async def on_invite_request(
@@ -821,11 +892,15 @@ class FederationHandler:
         if self.hs.config.server.block_non_admin_invites:
             raise SynapseError(403, "This server does not accept room invites")
 
-        if not await self.spam_checker.user_may_invite(
+        spam_check = await self.spam_checker.user_may_invite(
             event.sender, event.state_key, event.room_id
-        ):
+        )
+        if spam_check != NOT_SPAM:
             raise SynapseError(
-                403, "This user is not permitted to send invites to this server/user"
+                403,
+                "This user is not permitted to send invites to this server/user",
+                errcode=spam_check[0],
+                additional_fields=spam_check[1],
             )
 
         membership = event.content.get("membership")
@@ -972,9 +1047,7 @@ class FederationHandler:
         try:
             # The remote hasn't signed it yet, obviously. We'll do the full checks
             # when we get the event back in `on_send_leave_request`
-            await self._event_auth_handler.check_auth_rules_from_context(
-                room_version_obj, event, context
-            )
+            await self._event_auth_handler.check_auth_rules_from_context(event, context)
         except AuthError as e:
             logger.warning("Failed to create new leave %r because %s", event, e)
             raise e
@@ -1033,15 +1106,15 @@ class FederationHandler:
         try:
             # The remote hasn't signed it yet, obviously. We'll do the full checks
             # when we get the event back in `on_send_knock_request`
-            await self._event_auth_handler.check_auth_rules_from_context(
-                room_version_obj, event, context
-            )
+            await self._event_auth_handler.check_auth_rules_from_context(event, context)
         except AuthError as e:
             logger.warning("Failed to create new knock %r because %s", event, e)
             raise e
 
         return event
 
+    @trace
+    @tag_args
     async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
         """Returns the state at the event. i.e. not including said event."""
         event = await self.store.get_event(event_id, check_room_id=room_id)
@@ -1206,9 +1279,9 @@ class FederationHandler:
             event.internal_metadata.send_on_behalf_of = self.hs.hostname
 
             try:
-                validate_event_for_room_version(room_version_obj, event)
+                validate_event_for_room_version(event)
                 await self._event_auth_handler.check_auth_rules_from_context(
-                    room_version_obj, event, context
+                    event, context
                 )
             except AuthError as e:
                 logger.warning("Denying new third party invite %r because %s", event, e)
@@ -1258,10 +1331,8 @@ class FederationHandler:
         )
 
         try:
-            validate_event_for_room_version(room_version_obj, event)
-            await self._event_auth_handler.check_auth_rules_from_context(
-                room_version_obj, event, context
-            )
+            validate_event_for_room_version(event)
+            await self._event_auth_handler.check_auth_rules_from_context(event, context)
         except AuthError as e:
             logger.warning("Denying third party invite %r because %s", event, e)
             raise e
@@ -1506,14 +1577,17 @@ class FederationHandler:
         # TODO(faster_joins): do we need to lock to avoid races? What happens if other
         #   worker processes kick off a resync in parallel? Perhaps we should just elect
         #   a single worker to do the resync.
+        #   https://github.com/matrix-org/synapse/issues/12994
         #
         # TODO(faster_joins): what happens if we leave the room during a resync? if we
         #   really leave, that might mean we have difficulty getting the room state over
         #   federation.
+        #   https://github.com/matrix-org/synapse/issues/12802
         #
         # TODO(faster_joins): we need some way of prioritising which homeservers in
         #   `other_destinations` to try first, otherwise we'll spend ages trying dead
         #   homeservers for large rooms.
+        #   https://github.com/matrix-org/synapse/issues/12999
 
         if initial_destination is None and len(other_destinations) == 0:
             raise ValueError(
@@ -1522,15 +1596,16 @@ class FederationHandler:
 
         # Make an infinite iterator of destinations to try. Once we find a working
         # destination, we'll stick with it until it flakes.
+        destinations: Collection[str]
         if initial_destination is not None:
             # Move `initial_destination` to the front of the list.
             destinations = list(other_destinations)
             if initial_destination in destinations:
                 destinations.remove(initial_destination)
             destinations = [initial_destination] + destinations
-            destination_iter = itertools.cycle(destinations)
         else:
-            destination_iter = itertools.cycle(other_destinations)
+            destinations = other_destinations
+        destination_iter = itertools.cycle(destinations)
 
         # `destination` is the current remote homeserver we're pulling from.
         destination = next(destination_iter)
@@ -1543,12 +1618,9 @@ class FederationHandler:
                 # all the events are updated, so we can update current state and
                 # clear the lazy-loading flag.
                 logger.info("Updating current state for %s", room_id)
-                assert (
-                    self._storage_controllers.persistence is not None
-                ), "TODO(faster_joins): support for workers"
-                await self._storage_controllers.persistence.update_current_state(
-                    room_id
-                )
+                # TODO(faster_joins): notify workers in notify_room_un_partial_stated
+                #   https://github.com/matrix-org/synapse/issues/12994
+                await self.state_handler.update_current_state(room_id)
 
                 logger.info("Clearing partial-state flag for %s", room_id)
                 success = await self.store.clear_partial_state_room(room_id)
@@ -1559,13 +1631,12 @@ class FederationHandler:
                     )
 
                     # TODO(faster_joins) update room stats and user directory?
+                    #   https://github.com/matrix-org/synapse/issues/12814
+                    #   https://github.com/matrix-org/synapse/issues/12815
                     return
 
                 # we raced against more events arriving with partial state. Go round
                 # the loop again. We've already logged a warning, so no need for more.
-                # TODO(faster_joins): there is still a race here, whereby incoming events which raced
-                #   with us will fail to be persisted after the call to `clear_partial_state_room` due to
-                #   having partial state.
                 continue
 
             events = await self.store.get_events_as_list(
@@ -1588,6 +1659,7 @@ class FederationHandler:
                             #   indefinitely is also not the right thing to do if we can
                             #   reach all homeservers and they all claim they don't have
                             #   the state we want.
+                            #   https://github.com/matrix-org/synapse/issues/13000
                             logger.error(
                                 "Failed to get state for %s at %s from %s because %s, "
                                 "giving up!",
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 87a0608359..ace7adcffb 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import collections
 import itertools
 import logging
 from http import HTTPStatus
@@ -28,7 +29,7 @@ from typing import (
     Tuple,
 )
 
-from prometheus_client import Counter
+from prometheus_client import Counter, Histogram
 
 from synapse import event_auth
 from synapse.api.constants import (
@@ -50,19 +51,28 @@ from synapse.api.errors import (
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion, RoomVersions
 from synapse.event_auth import (
     auth_types_for_event,
-    check_auth_rules_for_event,
+    check_state_dependent_auth_rules,
+    check_state_independent_auth_rules,
     validate_event_for_room_version,
 )
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.federation.federation_client import InvalidResponseError
-from synapse.logging.context import nested_logging_context, run_in_background
+from synapse.logging.context import nested_logging_context
+from synapse.logging.opentracing import (
+    SynapseTags,
+    set_tag,
+    start_active_span,
+    tag_args,
+    trace,
+)
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
 from synapse.replication.http.federation import (
     ReplicationFederationSendEventsRestServlet,
 )
 from synapse.state import StateResolutionStore
+from synapse.storage.databases.main.events import PartialStateConflictError
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.state import StateFilter
 from synapse.types import (
@@ -88,6 +98,36 @@ soft_failed_event_counter = Counter(
     "Events received over federation that we marked as soft_failed",
 )
 
+# Added to debug performance and track progress on optimizations
+backfill_processing_after_timer = Histogram(
+    "synapse_federation_backfill_processing_after_time_seconds",
+    "sec",
+    [],
+    buckets=(
+        0.1,
+        0.25,
+        0.5,
+        1.0,
+        2.5,
+        5.0,
+        7.5,
+        10.0,
+        15.0,
+        20.0,
+        25.0,
+        30.0,
+        40.0,
+        50.0,
+        60.0,
+        80.0,
+        100.0,
+        120.0,
+        150.0,
+        180.0,
+        "+Inf",
+    ),
+)
+
 
 class FederationEventHandler:
     """Handles events that originated from federation.
@@ -274,7 +314,18 @@ class FederationEventHandler:
                     affected=pdu.event_id,
                 )
 
-        await self._process_received_pdu(origin, pdu, state_ids=None)
+        try:
+            context = await self._state_handler.compute_event_context(pdu)
+            await self._process_received_pdu(origin, pdu, context)
+        except PartialStateConflictError:
+            # The room was un-partial stated while we were processing the PDU.
+            # Try once more, with full state this time.
+            logger.info(
+                "Room %s was un-partial stated while processing the PDU, trying again.",
+                room_id,
+            )
+            context = await self._state_handler.compute_event_context(pdu)
+            await self._process_received_pdu(origin, pdu, context)
 
     async def on_send_membership_event(
         self, origin: str, event: EventBase
@@ -304,7 +355,11 @@ class FederationEventHandler:
             The event and context of the event after inserting it into the room graph.
 
         Raises:
+            RuntimeError if any prev_events are missing
             SynapseError if the event is not accepted into the room
+            PartialStateConflictError if the room was un-partial stated in between
+                computing the state at the event and persisting it. The caller should
+                retry exactly once in this case.
         """
         logger.debug(
             "on_send_membership_event: Got event: %s, signatures: %s",
@@ -333,7 +388,7 @@ class FederationEventHandler:
         event.internal_metadata.send_on_behalf_of = origin
 
         context = await self._state_handler.compute_event_context(event)
-        context = await self._check_event_auth(origin, event, context)
+        await self._check_event_auth(origin, event, context)
         if context.rejected:
             raise SynapseError(
                 403, f"{event.membership} event was rejected", Codes.FORBIDDEN
@@ -361,7 +416,7 @@ class FederationEventHandler:
         # need to.
         await self._event_creation_handler.cache_joined_hosts_for_event(event, context)
 
-        await self._check_for_soft_fail(event, None, origin=origin)
+        await self._check_for_soft_fail(event, context=context, origin=origin)
         await self._run_push_actions_and_persist_event(event, context)
         return event, context
 
@@ -391,6 +446,7 @@ class FederationEventHandler:
             prev_member_event,
         )
 
+    @trace
     async def process_remote_join(
         self,
         origin: str,
@@ -422,6 +478,8 @@ class FederationEventHandler:
 
         Raises:
             SynapseError if the response is in some way invalid.
+            PartialStateConflictError if the homeserver is already in the room and it
+                has been un-partial stated.
         """
         create_event = None
         for e in state:
@@ -469,7 +527,7 @@ class FederationEventHandler:
                 partial_state=partial_state,
             )
 
-            context = await self._check_event_auth(origin, event, context)
+            await self._check_event_auth(origin, event, context)
             if context.rejected:
                 raise SynapseError(400, "Join event was rejected")
 
@@ -517,31 +575,36 @@ class FederationEventHandler:
             #
             # This is the same operation as we do when we receive a regular event
             # over federation.
-            state_ids = await self._resolve_state_at_missing_prevs(destination, event)
-
-            # build a new state group for it if need be
-            context = await self._state_handler.compute_event_context(
-                event,
-                state_ids_before_event=state_ids,
+            context = await self._compute_event_context_with_maybe_missing_prevs(
+                destination, event
             )
             if context.partial_state:
                 # this can happen if some or all of the event's prev_events still have
-                # partial state - ie, an event has an earlier stream_ordering than one
-                # or more of its prev_events, so we de-partial-state it before its
-                # prev_events.
+                # partial state. We were careful to only pick events from the db without
+                # partial-state prev events, so that implies that a prev event has
+                # been persisted (with partial state) since we did the query.
                 #
-                # TODO(faster_joins): we probably need to be more intelligent, and
-                #    exclude partial-state prev_events from consideration
+                # So, let's just ignore `event` for now; when we re-run the db query
+                # we should instead get its partial-state prev event, which we will
+                # de-partial-state, and then come back to event.
                 logger.warning(
-                    "%s still has partial state: can't de-partial-state it yet",
+                    "%s still has prev_events with partial state: can't de-partial-state it yet",
                     event.event_id,
                 )
                 return
+
+            # since the state at this event has changed, we should now re-evaluate
+            # whether it should have been rejected. We must already have all of the
+            # auth events (from last time we went round this path), so there is no
+            # need to pass the origin.
+            await self._check_event_auth(None, event, context)
+
             await self._store.update_state_for_partial_state_event(event, context)
             self._state_storage_controller.notify_event_un_partial_stated(
                 event.event_id
             )
 
+    @trace
     async def backfill(
         self, dest: str, room_id: str, limit: int, extremities: Collection[str]
     ) -> None:
@@ -571,21 +634,23 @@ class FederationEventHandler:
         if not events:
             return
 
-        # if there are any events in the wrong room, the remote server is buggy and
-        # should not be trusted.
-        for ev in events:
-            if ev.room_id != room_id:
-                raise InvalidResponseError(
-                    f"Remote server {dest} returned event {ev.event_id} which is in "
-                    f"room {ev.room_id}, when we were backfilling in {room_id}"
-                )
+        with backfill_processing_after_timer.time():
+            # if there are any events in the wrong room, the remote server is buggy and
+            # should not be trusted.
+            for ev in events:
+                if ev.room_id != room_id:
+                    raise InvalidResponseError(
+                        f"Remote server {dest} returned event {ev.event_id} which is in "
+                        f"room {ev.room_id}, when we were backfilling in {room_id}"
+                    )
 
-        await self._process_pulled_events(
-            dest,
-            events,
-            backfilled=True,
-        )
+            await self._process_pulled_events(
+                dest,
+                events,
+                backfilled=True,
+            )
 
+    @trace
     async def _get_missing_events_for_pdu(
         self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
     ) -> None:
@@ -686,8 +751,9 @@ class FederationEventHandler:
         logger.info("Got %d prev_events", len(missing_events))
         await self._process_pulled_events(origin, missing_events, backfilled=False)
 
+    @trace
     async def _process_pulled_events(
-        self, origin: str, events: Iterable[EventBase], backfilled: bool
+        self, origin: str, events: Collection[EventBase], backfilled: bool
     ) -> None:
         """Process a batch of events we have pulled from a remote server
 
@@ -702,6 +768,15 @@ class FederationEventHandler:
             backfilled: True if this is part of a historical batch of events (inhibits
                 notification to clients, and validation of device keys.)
         """
+        set_tag(
+            SynapseTags.FUNC_ARG_PREFIX + "event_ids",
+            str([event.event_id for event in events]),
+        )
+        set_tag(
+            SynapseTags.FUNC_ARG_PREFIX + "event_ids.length",
+            str(len(events)),
+        )
+        set_tag(SynapseTags.FUNC_ARG_PREFIX + "backfilled", str(backfilled))
         logger.debug(
             "processing pulled backfilled=%s events=%s",
             backfilled,
@@ -724,6 +799,8 @@ class FederationEventHandler:
             with nested_logging_context(ev.event_id):
                 await self._process_pulled_event(origin, ev, backfilled=backfilled)
 
+    @trace
+    @tag_args
     async def _process_pulled_event(
         self, origin: str, event: EventBase, backfilled: bool
     ) -> None:
@@ -748,10 +825,24 @@ class FederationEventHandler:
         """
         logger.info("Processing pulled event %s", event)
 
-        # these should not be outliers.
-        assert (
-            not event.internal_metadata.is_outlier()
-        ), "pulled event unexpectedly flagged as outlier"
+        # This function should not be used to persist outliers (use something
+        # else) because this does a bunch of operations that aren't necessary
+        # (extra work; in particular, it makes sure we have all the prev_events
+        # and resolves the state across those prev events). If you happen to run
+        # into a situation where the event you're trying to process/backfill is
+        # marked as an `outlier`, then you should update that spot to return an
+        # `EventBase` copy that doesn't have `outlier` flag set.
+        #
+        # `EventBase` is used to represent both an event we have not yet
+        # persisted, and one that we have persisted and now keep in the cache.
+        # In an ideal world this method would only be called with the first type
+        # of event, but it turns out that's not actually the case and for
+        # example, you could get an event from cache that is marked as an
+        # `outlier` (fix up that spot though).
+        assert not event.internal_metadata.is_outlier(), (
+            "Outlier event passed to _process_pulled_event. "
+            "To persist an event as a non-outlier, make sure to pass in a copy without `event.internal_metadata.outlier = true`."
+        )
 
         event_id = event.event_id
 
@@ -761,7 +852,7 @@ class FederationEventHandler:
         if existing:
             if not existing.internal_metadata.is_outlier():
                 logger.info(
-                    "Ignoring received event %s which we have already seen",
+                    "_process_pulled_event: Ignoring received event %s which we have already seen",
                     event_id,
                 )
                 return
@@ -774,28 +865,56 @@ class FederationEventHandler:
             return
 
         try:
-            state_ids = await self._resolve_state_at_missing_prevs(origin, event)
-            # TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does
-            #   not return partial state
+            try:
+                context = await self._compute_event_context_with_maybe_missing_prevs(
+                    origin, event
+                )
+                await self._process_received_pdu(
+                    origin,
+                    event,
+                    context,
+                    backfilled=backfilled,
+                )
+            except PartialStateConflictError:
+                # The room was un-partial stated while we were processing the event.
+                # Try once more, with full state this time.
+                context = await self._compute_event_context_with_maybe_missing_prevs(
+                    origin, event
+                )
 
-            await self._process_received_pdu(
-                origin, event, state_ids=state_ids, backfilled=backfilled
-            )
+                # We ought to have full state now, barring some unlikely race where we left and
+                # rejoned the room in the background.
+                if context.partial_state:
+                    raise AssertionError(
+                        f"Event {event.event_id} still has a partial resolved state "
+                        f"after room {event.room_id} was un-partial stated"
+                    )
+
+                await self._process_received_pdu(
+                    origin,
+                    event,
+                    context,
+                    backfilled=backfilled,
+                )
         except FederationError as e:
             if e.code == 403:
                 logger.warning("Pulled event %s failed history check.", event_id)
             else:
                 raise
 
-    async def _resolve_state_at_missing_prevs(
+    @trace
+    async def _compute_event_context_with_maybe_missing_prevs(
         self, dest: str, event: EventBase
-    ) -> Optional[StateMap[str]]:
-        """Calculate the state at an event with missing prev_events.
+    ) -> EventContext:
+        """Build an EventContext structure for a non-outlier event whose prev_events may
+        be missing.
 
-        This is used when we have pulled a batch of events from a remote server, and
-        still don't have all the prev_events.
+        This is used when we have pulled a batch of events from a remote server, and may
+        not have all the prev_events.
 
-        If we already have all the prev_events for `event`, this method does nothing.
+        To build an EventContext, we need to calculate the state before the event. If we
+        already have all the prev_events for `event`, we can simply use the state after
+        the prev_events to calculate the state before `event`.
 
         Otherwise, the missing prevs become new backwards extremities, and we fall back
         to asking the remote server for the state after each missing `prev_event`,
@@ -816,8 +935,7 @@ class FederationEventHandler:
             event: an event to check for missing prevs.
 
         Returns:
-            if we already had all the prev events, `None`. Otherwise, returns
-            the event ids of the state at `event`.
+            The event context.
 
         Raises:
             FederationError if we fail to get the state from the remote server after any
@@ -831,7 +949,7 @@ class FederationEventHandler:
         missing_prevs = prevs - seen
 
         if not missing_prevs:
-            return None
+            return await self._state_handler.compute_event_context(event)
 
         logger.info(
             "Event %s is missing prev_events %s: calculating state for a "
@@ -843,9 +961,15 @@ class FederationEventHandler:
         # resolve them to find the correct state at the current event.
 
         try:
+            # Determine whether we may be about to retrieve partial state
+            # Events may be un-partial stated right after we compute the partial state
+            # flag, but that's okay, as long as the flag errs on the conservative side.
+            partial_state_flags = await self._store.get_partial_state_events(seen)
+            partial_state = any(partial_state_flags.values())
+
             # Get the state of the events we know about
             ours = await self._state_storage_controller.get_state_groups_ids(
-                room_id, seen
+                room_id, seen, await_full_state=False
             )
 
             # state_maps is a list of mappings from (type, state_key) to event_id
@@ -891,8 +1015,12 @@ class FederationEventHandler:
                 "We can't get valid state history.",
                 affected=event_id,
             )
-        return state_map
+        return await self._state_handler.compute_event_context(
+            event, state_ids_before_event=state_map, partial_state=partial_state
+        )
 
+    @trace
+    @tag_args
     async def _get_state_ids_after_missing_prev_event(
         self,
         destination: str,
@@ -913,6 +1041,14 @@ class FederationEventHandler:
             InvalidResponseError: if the remote homeserver's response contains fields
                 of the wrong type.
         """
+
+        # It would be better if we could query the difference from our known
+        # state to the given `event_id` so the sending server doesn't have to
+        # send as much and we don't have to process as many events. For example
+        # in a room like #matrix:matrix.org, we get 200k events (77k state_events, 122k
+        # auth_events) from this call.
+        #
+        # Tracked by https://github.com/matrix-org/synapse/issues/13618
         (
             state_event_ids,
             auth_event_ids,
@@ -932,10 +1068,10 @@ class FederationEventHandler:
         logger.debug("Fetching %i events from cache/store", len(desired_events))
         have_events = await self._store.have_seen_events(room_id, desired_events)
 
-        missing_desired_events = desired_events - have_events
+        missing_desired_event_ids = desired_events - have_events
         logger.debug(
             "We are missing %i events (got %i)",
-            len(missing_desired_events),
+            len(missing_desired_event_ids),
             len(have_events),
         )
 
@@ -947,13 +1083,30 @@ class FederationEventHandler:
         #   already have a bunch of the state events. It would be nice if the
         #   federation api gave us a way of finding out which we actually need.
 
-        missing_auth_events = set(auth_event_ids) - have_events
-        missing_auth_events.difference_update(
-            await self._store.have_seen_events(room_id, missing_auth_events)
+        missing_auth_event_ids = set(auth_event_ids) - have_events
+        missing_auth_event_ids.difference_update(
+            await self._store.have_seen_events(room_id, missing_auth_event_ids)
         )
-        logger.debug("We are also missing %i auth events", len(missing_auth_events))
+        logger.debug("We are also missing %i auth events", len(missing_auth_event_ids))
 
-        missing_events = missing_desired_events | missing_auth_events
+        missing_event_ids = missing_desired_event_ids | missing_auth_event_ids
+
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "missing_auth_event_ids",
+            str(missing_auth_event_ids),
+        )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "missing_auth_event_ids.length",
+            str(len(missing_auth_event_ids)),
+        )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "missing_desired_event_ids",
+            str(missing_desired_event_ids),
+        )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "missing_desired_event_ids.length",
+            str(len(missing_desired_event_ids)),
+        )
 
         # Making an individual request for each of 1000s of events has a lot of
         # overhead. On the other hand, we don't really want to fetch all of the events
@@ -964,13 +1117,13 @@ class FederationEventHandler:
         #
         # TODO: might it be better to have an API which lets us do an aggregate event
         #   request
-        if (len(missing_events) * 10) >= len(auth_event_ids) + len(state_event_ids):
+        if (len(missing_event_ids) * 10) >= len(auth_event_ids) + len(state_event_ids):
             logger.debug("Requesting complete state from remote")
             await self._get_state_and_persist(destination, room_id, event_id)
         else:
-            logger.debug("Fetching %i events from remote", len(missing_events))
+            logger.debug("Fetching %i events from remote", len(missing_event_ids))
             await self._get_events_and_persist(
-                destination=destination, room_id=room_id, event_ids=missing_events
+                destination=destination, room_id=room_id, event_ids=missing_event_ids
             )
 
         # We now need to fill out the state map, which involves fetching the
@@ -1018,12 +1171,23 @@ class FederationEventHandler:
         # XXX: this doesn't sound right? it means that we'll end up with incomplete
         #   state.
         failed_to_fetch = desired_events - event_metadata.keys()
+        # `event_id` could be missing from `event_metadata` because it's not necessarily
+        # a state event. We've already checked that we've fetched it above.
+        failed_to_fetch.discard(event_id)
         if failed_to_fetch:
             logger.warning(
                 "Failed to fetch missing state events for %s %s",
                 event_id,
                 failed_to_fetch,
             )
+            set_tag(
+                SynapseTags.RESULT_PREFIX + "failed_to_fetch",
+                str(failed_to_fetch),
+            )
+            set_tag(
+                SynapseTags.RESULT_PREFIX + "failed_to_fetch.length",
+                str(len(failed_to_fetch)),
+            )
 
         if remote_event.is_state() and remote_event.rejected_reason is None:
             state_map[
@@ -1032,6 +1196,8 @@ class FederationEventHandler:
 
         return state_map
 
+    @trace
+    @tag_args
     async def _get_state_and_persist(
         self, destination: str, room_id: str, event_id: str
     ) -> None:
@@ -1053,11 +1219,12 @@ class FederationEventHandler:
                 destination=destination, room_id=room_id, event_ids=(event_id,)
             )
 
+    @trace
     async def _process_received_pdu(
         self,
         origin: str,
         event: EventBase,
-        state_ids: Optional[StateMap[str]],
+        context: EventContext,
         backfilled: bool = False,
     ) -> None:
         """Called when we have a new non-outlier event.
@@ -1079,37 +1246,30 @@ class FederationEventHandler:
 
             event: event to be persisted
 
-            state_ids: Normally None, but if we are handling a gap in the graph
-                (ie, we are missing one or more prev_events), the resolved state at the
-                event
+            context: The `EventContext` to persist the event with.
 
             backfilled: True if this is part of a historical batch of events (inhibits
                 notification to clients, and validation of device keys.)
+
+        PartialStateConflictError: if the room was un-partial stated in between
+            computing the state at the event and persisting it. The caller should
+            recompute `context` and retry exactly once when this happens.
         """
         logger.debug("Processing event: %s", event)
         assert not event.internal_metadata.outlier
 
         try:
-            context = await self._state_handler.compute_event_context(
-                event,
-                state_ids_before_event=state_ids,
-            )
-            context = await self._check_event_auth(
-                origin,
-                event,
-                context,
-            )
+            await self._check_event_auth(origin, event, context)
         except AuthError as e:
-            # FIXME richvdh 2021/10/07 I don't think this is reachable. Let's log it
-            #   for now
-            logger.exception("Unexpected AuthError from _check_event_auth")
+            # This happens only if we couldn't find the auth events. We'll already have
+            # logged a warning, so now we just convert to a FederationError.
             raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
 
         if not backfilled and not context.rejected:
             # For new (non-backfilled and non-outlier) events we check if the event
             # passes auth based on the current state. If it doesn't then we
             # "soft-fail" the event.
-            await self._check_for_soft_fail(event, state_ids, origin=origin)
+            await self._check_for_soft_fail(event, context=context, origin=origin)
 
         await self._run_push_actions_and_persist_event(event, context, backfilled)
 
@@ -1210,6 +1370,7 @@ class FederationEventHandler:
         except Exception:
             logger.exception("Failed to resync device for %s", sender)
 
+    @trace
     async def _handle_marker_event(self, origin: str, marker_event: EventBase) -> None:
         """Handles backfilling the insertion event when we receive a marker
         event that points to one.
@@ -1241,7 +1402,7 @@ class FederationEventHandler:
         logger.debug("_handle_marker_event: received %s", marker_event)
 
         insertion_event_id = marker_event.content.get(
-            EventContentFields.MSC2716_MARKER_INSERTION
+            EventContentFields.MSC2716_INSERTION_EVENT_REFERENCE
         )
 
         if insertion_event_id is None:
@@ -1294,6 +1455,55 @@ class FederationEventHandler:
             marker_event,
         )
 
+    async def backfill_event_id(
+        self, destination: str, room_id: str, event_id: str
+    ) -> EventBase:
+        """Backfill a single event and persist it as a non-outlier which means
+        we also pull in all of the state and auth events necessary for it.
+
+        Args:
+            destination: The homeserver to pull the given event_id from.
+            room_id: The room where the event is from.
+            event_id: The event ID to backfill.
+
+        Raises:
+            FederationError if we are unable to find the event from the destination
+        """
+        logger.info(
+            "backfill_event_id: event_id=%s from destination=%s", event_id, destination
+        )
+
+        room_version = await self._store.get_room_version(room_id)
+
+        event_from_response = await self._federation_client.get_pdu(
+            [destination],
+            event_id,
+            room_version,
+        )
+
+        if not event_from_response:
+            raise FederationError(
+                "ERROR",
+                404,
+                "Unable to find event_id=%s from destination=%s to backfill."
+                % (event_id, destination),
+                affected=event_id,
+            )
+
+        # Persist the event we just fetched, including pulling all of the state
+        # and auth events to de-outlier it. This also sets up the necessary
+        # `state_groups` for the event.
+        await self._process_pulled_events(
+            destination,
+            [event_from_response],
+            # Prevent notifications going to clients
+            backfilled=True,
+        )
+
+        return event_from_response
+
+    @trace
+    @tag_args
     async def _get_events_and_persist(
         self, destination: str, room_id: str, event_ids: Collection[str]
     ) -> None:
@@ -1339,6 +1549,7 @@ class FederationEventHandler:
         logger.info("Fetched %i events of %i requested", len(events), len(event_ids))
         await self._auth_and_persist_outliers(room_id, events)
 
+    @trace
     async def _auth_and_persist_outliers(
         self, room_id: str, events: Iterable[EventBase]
     ) -> None:
@@ -1357,6 +1568,16 @@ class FederationEventHandler:
         """
         event_map = {event.event_id: event for event in events}
 
+        event_ids = event_map.keys()
+        set_tag(
+            SynapseTags.FUNC_ARG_PREFIX + "event_ids",
+            str(event_ids),
+        )
+        set_tag(
+            SynapseTags.FUNC_ARG_PREFIX + "event_ids.length",
+            str(len(event_ids)),
+        )
+
         # filter out any events we have already seen. This might happen because
         # the events were eagerly pushed to us (eg, during a room join), or because
         # another thread has raced against us since we decided to request the event.
@@ -1428,10 +1649,9 @@ class FederationEventHandler:
             allow_rejected=True,
         )
 
-        room_version = await self._store.get_room_version_id(room_id)
-        room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+        events_and_contexts_to_persist: List[Tuple[EventBase, EventContext]] = []
 
-        def prep(event: EventBase) -> Optional[Tuple[EventBase, EventContext]]:
+        async def prep(event: EventBase) -> None:
             with nested_logging_context(suffix=event.event_id):
                 auth = []
                 for auth_event_id in event.auth_event_ids():
@@ -1445,7 +1665,7 @@ class FederationEventHandler:
                             event,
                             auth_event_id,
                         )
-                        return None
+                        return
                     auth.append(ae)
 
                 # we're not bothering about room state, so flag the event as an outlier.
@@ -1453,42 +1673,42 @@ class FederationEventHandler:
 
                 context = EventContext.for_outlier(self._storage_controllers)
                 try:
-                    validate_event_for_room_version(room_version_obj, event)
-                    check_auth_rules_for_event(room_version_obj, event, auth)
+                    validate_event_for_room_version(event)
+                    await check_state_independent_auth_rules(self._store, event)
+                    check_state_dependent_auth_rules(event, auth)
                 except AuthError as e:
                     logger.warning("Rejecting %r because %s", event, e)
                     context.rejected = RejectedReason.AUTH_ERROR
 
-            return event, context
+            events_and_contexts_to_persist.append((event, context))
+
+        for event in fetched_events:
+            await prep(event)
 
-        events_to_persist = (x for x in (prep(event) for event in fetched_events) if x)
         await self.persist_events_and_notify(
             room_id,
-            tuple(events_to_persist),
+            events_and_contexts_to_persist,
             # Mark these events backfilled as they're historic events that will
             # eventually be backfilled. For example, missing events we fetch
             # during backfill should be marked as backfilled as well.
             backfilled=True,
         )
 
+    @trace
     async def _check_event_auth(
-        self,
-        origin: str,
-        event: EventBase,
-        context: EventContext,
-    ) -> EventContext:
+        self, origin: Optional[str], event: EventBase, context: EventContext
+    ) -> None:
         """
         Checks whether an event should be rejected (for failing auth checks).
 
         Args:
-            origin: The host the event originates from.
+            origin: The host the event originates from. This is used to fetch
+               any missing auth events. It can be set to None, but only if we are
+               sure that we already have all the auth events.
             event: The event itself.
             context:
                 The event context.
 
-        Returns:
-            The updated context object.
-
         Raises:
             AuthError if we were unable to find copies of the event's auth events.
                (Most other failures just cause us to set `context.rejected`.)
@@ -1497,16 +1717,13 @@ class FederationEventHandler:
         assert not event.internal_metadata.outlier
 
         # first of all, check that the event itself is valid.
-        room_version = await self._store.get_room_version_id(event.room_id)
-        room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
-
         try:
-            validate_event_for_room_version(room_version_obj, event)
+            validate_event_for_room_version(event)
         except AuthError as e:
             logger.warning("While validating received event %r: %s", event, e)
             # TODO: use a different rejected reason here?
             context.rejected = RejectedReason.AUTH_ERROR
-            return context
+            return
 
         # next, check that we have all of the event's auth events.
         #
@@ -1516,66 +1733,112 @@ class FederationEventHandler:
         claimed_auth_events = await self._load_or_fetch_auth_events_for_event(
             origin, event
         )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "claimed_auth_events",
+            str([ev.event_id for ev in claimed_auth_events]),
+        )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "claimed_auth_events.length",
+            str(len(claimed_auth_events)),
+        )
 
         # ... and check that the event passes auth at those auth events.
+        # https://spec.matrix.org/v1.3/server-server-api/#checks-performed-on-receipt-of-a-pdu:
+        #   4. Passes authorization rules based on the event’s auth events,
+        #      otherwise it is rejected.
         try:
-            check_auth_rules_for_event(room_version_obj, event, claimed_auth_events)
+            await check_state_independent_auth_rules(self._store, event)
+            check_state_dependent_auth_rules(event, claimed_auth_events)
         except AuthError as e:
             logger.warning(
                 "While checking auth of %r against auth_events: %s", event, e
             )
             context.rejected = RejectedReason.AUTH_ERROR
-            return context
+            return
+
+        # now check the auth rules pass against the room state before the event
+        # https://spec.matrix.org/v1.3/server-server-api/#checks-performed-on-receipt-of-a-pdu:
+        #   5. Passes authorization rules based on the state before the event,
+        #      otherwise it is rejected.
+        #
+        # ... however, if we only have partial state for the room, then there is a good
+        # chance that we'll be missing some of the state needed to auth the new event.
+        # So, we state-resolve the auth events that we are given against the state that
+        # we know about, which ensures things like bans are applied. (Note that we'll
+        # already have checked we have all the auth events, in
+        # _load_or_fetch_auth_events_for_event above)
+        if context.partial_state:
+            room_version = await self._store.get_room_version_id(event.room_id)
+
+            local_state_id_map = await context.get_prev_state_ids()
+            claimed_auth_events_id_map = {
+                (ev.type, ev.state_key): ev.event_id for ev in claimed_auth_events
+            }
+
+            state_for_auth_id_map = (
+                await self._state_resolution_handler.resolve_events_with_store(
+                    event.room_id,
+                    room_version,
+                    [local_state_id_map, claimed_auth_events_id_map],
+                    event_map=None,
+                    state_res_store=StateResolutionStore(self._store),
+                )
+            )
+        else:
+            event_types = event_auth.auth_types_for_event(event.room_version, event)
+            state_for_auth_id_map = await context.get_prev_state_ids(
+                StateFilter.from_types(event_types)
+            )
 
-        # now check auth against what we think the auth events *should* be.
-        event_types = event_auth.auth_types_for_event(event.room_version, event)
-        prev_state_ids = await context.get_prev_state_ids(
-            StateFilter.from_types(event_types)
+        calculated_auth_event_ids = self._event_auth_handler.compute_auth_events(
+            event, state_for_auth_id_map, for_verification=True
         )
 
-        auth_events_ids = self._event_auth_handler.compute_auth_events(
-            event, prev_state_ids, for_verification=True
+        # if those are the same, we're done here.
+        if collections.Counter(event.auth_event_ids()) == collections.Counter(
+            calculated_auth_event_ids
+        ):
+            return
+
+        # otherwise, re-run the auth checks based on what we calculated.
+        calculated_auth_events = await self._store.get_events_as_list(
+            calculated_auth_event_ids
         )
-        auth_events_x = await self._store.get_events(auth_events_ids)
+
+        # log the differences
+
+        claimed_auth_event_map = {(e.type, e.state_key): e for e in claimed_auth_events}
         calculated_auth_event_map = {
-            (e.type, e.state_key): e for e in auth_events_x.values()
+            (e.type, e.state_key): e for e in calculated_auth_events
         }
+        logger.info(
+            "event's auth_events are different to our calculated auth_events. "
+            "Claimed but not calculated: %s. Calculated but not claimed: %s",
+            [
+                ev
+                for k, ev in claimed_auth_event_map.items()
+                if k not in calculated_auth_event_map
+                or calculated_auth_event_map[k].event_id != ev.event_id
+            ],
+            [
+                ev
+                for k, ev in calculated_auth_event_map.items()
+                if k not in claimed_auth_event_map
+                or claimed_auth_event_map[k].event_id != ev.event_id
+            ],
+        )
 
         try:
-            updated_auth_events = await self._update_auth_events_for_auth(
+            check_state_dependent_auth_rules(event, calculated_auth_events)
+        except AuthError as e:
+            logger.warning(
+                "While checking auth of %r against room state before the event: %s",
                 event,
-                calculated_auth_event_map=calculated_auth_event_map,
-            )
-        except Exception:
-            # We don't really mind if the above fails, so lets not fail
-            # processing if it does. However, it really shouldn't fail so
-            # let's still log as an exception since we'll still want to fix
-            # any bugs.
-            logger.exception(
-                "Failed to double check auth events for %s with remote. "
-                "Ignoring failure and continuing processing of event.",
-                event.event_id,
-            )
-            updated_auth_events = None
-
-        if updated_auth_events:
-            context = await self._update_context_for_auth_events(
-                event, context, updated_auth_events
-            )
-            auth_events_for_auth = updated_auth_events
-        else:
-            auth_events_for_auth = calculated_auth_event_map
-
-        try:
-            check_auth_rules_for_event(
-                room_version_obj, event, auth_events_for_auth.values()
+                e,
             )
-        except AuthError as e:
-            logger.warning("Failed auth resolution for %r because %s", event, e)
             context.rejected = RejectedReason.AUTH_ERROR
 
-        return context
-
+    @trace
     async def _maybe_kick_guest_users(self, event: EventBase) -> None:
         if event.type != EventTypes.GuestAccess:
             return
@@ -1593,17 +1856,27 @@ class FederationEventHandler:
     async def _check_for_soft_fail(
         self,
         event: EventBase,
-        state_ids: Optional[StateMap[str]],
+        context: EventContext,
         origin: str,
     ) -> None:
         """Checks if we should soft fail the event; if so, marks the event as
         such.
 
+        Does nothing for events in rooms with partial state, since we may not have an
+        accurate membership event for the sender in the current state.
+
         Args:
             event
-            state_ids: The state at the event if we don't have all the event's prev events
+            context: The `EventContext` which we are about to persist the event with.
             origin: The host the event originates from.
         """
+        if await self._store.is_partial_state_room(event.room_id):
+            # We might not know the sender's membership in the current state, so don't
+            # soft fail anything. Even if we do have a membership for the sender in the
+            # current state, it may have been derived from state resolution between
+            # partial and full state and may not be accurate.
+            return
+
         extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id)
         extrem_ids = set(extrem_ids_list)
         prev_event_ids = set(event.prev_event_ids())
@@ -1620,11 +1893,15 @@ class FederationEventHandler:
         auth_types = auth_types_for_event(room_version_obj, event)
 
         # Calculate the "current state".
-        if state_ids is not None:
-            # If we're explicitly given the state then we won't have all the
-            # prev events, and so we have a gap in the graph. In this case
-            # we want to be a little careful as we might have been down for
-            # a while and have an incorrect view of the current state,
+        seen_event_ids = await self._store.have_events_in_timeline(prev_event_ids)
+        has_missing_prevs = bool(prev_event_ids - seen_event_ids)
+        if has_missing_prevs:
+            # We don't have all the prev_events of this event, which means we have a
+            # gap in the graph, and the new event is going to become a new backwards
+            # extremity.
+            #
+            # In this case we want to be a little careful as we might have been
+            # down for a while and have an incorrect view of the current state,
             # however we still want to do checks as gaps are easy to
             # maliciously manufacture.
             #
@@ -1637,6 +1914,7 @@ class FederationEventHandler:
                 event.room_id, extrem_ids
             )
             state_sets: List[StateMap[str]] = list(state_sets_d.values())
+            state_ids = await context.get_prev_state_ids()
             state_sets.append(state_ids)
             current_state_ids = (
                 await self._state_resolution_handler.resolve_events_with_store(
@@ -1669,7 +1947,7 @@ class FederationEventHandler:
         )
 
         try:
-            check_auth_rules_for_event(room_version_obj, event, current_auth_events)
+            check_state_dependent_auth_rules(event, current_auth_events)
         except AuthError as e:
             logger.warning(
                 "Soft-failing %r (from %s) because %s",
@@ -1685,95 +1963,8 @@ class FederationEventHandler:
             soft_failed_event_counter.inc()
             event.internal_metadata.soft_failed = True
 
-    async def _update_auth_events_for_auth(
-        self,
-        event: EventBase,
-        calculated_auth_event_map: StateMap[EventBase],
-    ) -> Optional[StateMap[EventBase]]:
-        """Helper for _check_event_auth. See there for docs.
-
-        Checks whether a given event has the expected auth events. If it
-        doesn't then we talk to the remote server to compare state to see if
-        we can come to a consensus (e.g. if one server missed some valid
-        state).
-
-        This attempts to resolve any potential divergence of state between
-        servers, but is not essential and so failures should not block further
-        processing of the event.
-
-        Args:
-            event:
-
-            calculated_auth_event_map:
-                Our calculated auth_events based on the state of the room
-                at the event's position in the DAG.
-
-        Returns:
-            updated auth event map, or None if no changes are needed.
-
-        """
-        assert not event.internal_metadata.outlier
-
-        # check for events which are in the event's claimed auth_events, but not
-        # in our calculated event map.
-        event_auth_events = set(event.auth_event_ids())
-        different_auth = event_auth_events.difference(
-            e.event_id for e in calculated_auth_event_map.values()
-        )
-
-        if not different_auth:
-            return None
-
-        logger.info(
-            "auth_events refers to events which are not in our calculated auth "
-            "chain: %s",
-            different_auth,
-        )
-
-        # XXX: currently this checks for redactions but I'm not convinced that is
-        # necessary?
-        different_events = await self._store.get_events_as_list(different_auth)
-
-        # double-check they're all in the same room - we should already have checked
-        # this but it doesn't hurt to check again.
-        for d in different_events:
-            assert (
-                d.room_id == event.room_id
-            ), f"Event {event.event_id} refers to auth_event {d.event_id} which is in a different room"
-
-        # now we state-resolve between our own idea of the auth events, and the remote's
-        # idea of them.
-
-        local_state = calculated_auth_event_map.values()
-        remote_auth_events = dict(calculated_auth_event_map)
-        remote_auth_events.update({(d.type, d.state_key): d for d in different_events})
-        remote_state = remote_auth_events.values()
-
-        room_version = await self._store.get_room_version_id(event.room_id)
-        new_state = await self._state_handler.resolve_events(
-            room_version, (local_state, remote_state), event
-        )
-        different_state = {
-            (d.type, d.state_key): d
-            for d in new_state.values()
-            if calculated_auth_event_map.get((d.type, d.state_key)) != d
-        }
-        if not different_state:
-            logger.info("State res returned no new state")
-            return None
-
-        logger.info(
-            "After state res: updating auth_events with new state %s",
-            different_state.values(),
-        )
-
-        # take a copy of calculated_auth_event_map before we modify it.
-        auth_events = dict(calculated_auth_event_map)
-        auth_events.update(different_state)
-        return auth_events
-
     async def _load_or_fetch_auth_events_for_event(
-        self, destination: str, event: EventBase
+        self, destination: Optional[str], event: EventBase
     ) -> Collection[EventBase]:
         """Fetch this event's auth_events, from database or remote
 
@@ -1789,12 +1980,19 @@ class FederationEventHandler:
         Args:
             destination: where to send the /event_auth request. Typically the server
                that sent us `event` in the first place.
+
+               If this is None, no attempt is made to load any missing auth events:
+               rather, an AssertionError is raised if there are any missing events.
+
             event: the event whose auth_events we want
 
         Returns:
             all of the events listed in `event.auth_events_ids`, after deduplication
 
         Raises:
+            AssertionError if some auth events were missing and no `destination` was
+            supplied.
+
             AuthError if we were unable to fetch the auth_events for any reason.
         """
         event_auth_event_ids = set(event.auth_event_ids())
@@ -1806,6 +2004,13 @@ class FederationEventHandler:
         )
         if not missing_auth_event_ids:
             return event_auth_events.values()
+        if destination is None:
+            # this shouldn't happen: destination must be set unless we know we have already
+            # persisted the auth events.
+            raise AssertionError(
+                "_load_or_fetch_auth_events_for_event() called with no destination for "
+                "an event with missing auth_events"
+            )
 
         logger.info(
             "Event %s refers to unknown auth events %s: fetching auth chain",
@@ -1841,6 +2046,8 @@ class FederationEventHandler:
         # instead we raise an AuthError, which will make the caller ignore it.
         raise AuthError(code=HTTPStatus.FORBIDDEN, msg="Auth events could not be found")
 
+    @trace
+    @tag_args
     async def _get_remote_auth_chain_for_event(
         self, destination: str, room_id: str, event_id: str
     ) -> None:
@@ -1869,61 +2076,7 @@ class FederationEventHandler:
 
         await self._auth_and_persist_outliers(room_id, remote_auth_events)
 
-    async def _update_context_for_auth_events(
-        self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase]
-    ) -> EventContext:
-        """Update the state_ids in an event context after auth event resolution,
-        storing the changes as a new state group.
-
-        Args:
-            event: The event we're handling the context for
-
-            context: initial event context
-
-            auth_events: Events to update in the event context.
-
-        Returns:
-            new event context
-        """
-        # exclude the state key of the new event from the current_state in the context.
-        if event.is_state():
-            event_key: Optional[Tuple[str, str]] = (event.type, event.state_key)
-        else:
-            event_key = None
-        state_updates = {
-            k: a.event_id for k, a in auth_events.items() if k != event_key
-        }
-
-        current_state_ids = await context.get_current_state_ids()
-        current_state_ids = dict(current_state_ids)  # type: ignore
-
-        current_state_ids.update(state_updates)
-
-        prev_state_ids = await context.get_prev_state_ids()
-        prev_state_ids = dict(prev_state_ids)
-
-        prev_state_ids.update({k: a.event_id for k, a in auth_events.items()})
-
-        # create a new state group as a delta from the existing one.
-        prev_group = context.state_group
-        state_group = await self._state_storage_controller.store_state_group(
-            event.event_id,
-            event.room_id,
-            prev_group=prev_group,
-            delta_ids=state_updates,
-            current_state_ids=current_state_ids,
-        )
-
-        return EventContext.with_state(
-            storage=self._storage_controllers,
-            state_group=state_group,
-            state_group_before_event=context.state_group_before_event,
-            state_delta_due_to_event=state_updates,
-            prev_group=prev_group,
-            delta_ids=state_updates,
-            partial_state=context.partial_state,
-        )
-
+    @trace
     async def _run_push_actions_and_persist_event(
         self, event: EventBase, context: EventContext, backfilled: bool = False
     ) -> None:
@@ -1933,6 +2086,9 @@ class FederationEventHandler:
             event: The event itself.
             context: The event context.
             backfilled: True if the event was backfilled.
+
+        PartialStateConflictError: if attempting to persist a partial state event in
+            a room that has been un-partial stated.
         """
         # this method should not be called on outliers (those code paths call
         # persist_events_and_notify directly.)
@@ -1963,9 +2119,7 @@ class FederationEventHandler:
                 event.room_id, [(event, context)], backfilled=backfilled
             )
         except Exception:
-            run_in_background(
-                self._store.remove_push_actions_from_staging, event.event_id
-            )
+            await self._store.remove_push_actions_from_staging(event.event_id)
             raise
 
     async def persist_events_and_notify(
@@ -1987,6 +2141,10 @@ class FederationEventHandler:
 
         Returns:
             The stream ID after which all events have been persisted.
+
+        Raises:
+            PartialStateConflictError: if attempting to persist a partial state event in
+                a room that has been un-partial stated.
         """
         if not event_and_contexts:
             return self._store.get_room_max_stream_ordering()
@@ -1995,14 +2153,19 @@ class FederationEventHandler:
         if instance != self._instance_name:
             # Limit the number of events sent over replication. We choose 200
             # here as that is what we default to in `max_request_body_size(..)`
-            for batch in batch_iter(event_and_contexts, 200):
-                result = await self._send_events(
-                    instance_name=instance,
-                    store=self._store,
-                    room_id=room_id,
-                    event_and_contexts=batch,
-                    backfilled=backfilled,
-                )
+            try:
+                for batch in batch_iter(event_and_contexts, 200):
+                    result = await self._send_events(
+                        instance_name=instance,
+                        store=self._store,
+                        room_id=room_id,
+                        event_and_contexts=batch,
+                        backfilled=backfilled,
+                    )
+            except SynapseError as e:
+                if e.code == HTTPStatus.CONFLICT:
+                    raise PartialStateConflictError()
+                raise
             return result["max_stream_id"]
         else:
             assert self._storage_controllers.persistence
@@ -2022,8 +2185,17 @@ class FederationEventHandler:
                     self._message_handler.maybe_schedule_expiry(event)
 
             if not backfilled:  # Never notify for backfilled events
-                for event in events:
-                    await self._notify_persisted_event(event, max_stream_token)
+                with start_active_span("notify_persisted_events"):
+                    set_tag(
+                        SynapseTags.RESULT_PREFIX + "event_ids",
+                        str([ev.event_id for ev in events]),
+                    )
+                    set_tag(
+                        SynapseTags.RESULT_PREFIX + "event_ids.length",
+                        str(len(events)),
+                    )
+                    for event in events:
+                        await self._notify_persisted_event(event, max_stream_token)
 
             return max_stream_token.stream
 
@@ -2064,6 +2236,10 @@ class FederationEventHandler:
             event, event_pos, max_stream_token, extra_users=extra_users
         )
 
+        if event.type == EventTypes.Member and event.membership == Membership.JOIN:
+            # TODO retrieve the previous state, and exclude join -> join transitions
+            self._notifier.notify_user_joined_room(event.event_id, event.room_id)
+
     def _sanity_check_event(self, ev: EventBase) -> None:
         """
         Do some early sanity checks of a received event
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 9bca2bc4b2..93d09e9939 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -26,7 +26,6 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.api.ratelimiting import Ratelimiter
-from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.http import RequestTimedOutError
 from synapse.http.client import SimpleHttpClient
 from synapse.http.site import SynapseRequest
@@ -163,8 +162,7 @@ class IdentityHandler:
         sid: str,
         mxid: str,
         id_server: str,
-        id_access_token: Optional[str] = None,
-        use_v2: bool = True,
+        id_access_token: str,
     ) -> JsonDict:
         """Bind a 3PID to an identity server
 
@@ -174,8 +172,7 @@ class IdentityHandler:
             mxid: The MXID to bind the 3PID to
             id_server: The domain of the identity server to query
             id_access_token: The access token to authenticate to the identity
-                server with, if necessary. Required if use_v2 is true
-            use_v2: Whether to use v2 Identity Service API endpoints. Defaults to True
+                server with
 
         Raises:
             SynapseError: On any of the following conditions
@@ -187,24 +184,15 @@ class IdentityHandler:
         """
         logger.debug("Proxying threepid bind request for %s to %s", mxid, id_server)
 
-        # If an id_access_token is not supplied, force usage of v1
-        if id_access_token is None:
-            use_v2 = False
-
         if not valid_id_server_location(id_server):
             raise SynapseError(
                 400,
                 "id_server must be a valid hostname with optional port and path components",
             )
 
-        # Decide which API endpoint URLs to use
-        headers = {}
         bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid}
-        if use_v2:
-            bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,)
-            headers["Authorization"] = create_id_access_token_header(id_access_token)  # type: ignore
-        else:
-            bind_url = "https://%s/_matrix/identity/api/v1/3pid/bind" % (id_server,)
+        bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,)
+        headers = {"Authorization": create_id_access_token_header(id_access_token)}
 
         try:
             # Use the blacklisting http client as this call is only to identity servers
@@ -223,21 +211,14 @@ class IdentityHandler:
 
             return data
         except HttpResponseException as e:
-            if e.code != 404 or not use_v2:
-                logger.error("3PID bind failed with Matrix error: %r", e)
-                raise e.to_synapse_error()
+            logger.error("3PID bind failed with Matrix error: %r", e)
+            raise e.to_synapse_error()
         except RequestTimedOutError:
             raise SynapseError(500, "Timed out contacting identity server")
         except CodeMessageException as e:
             data = json_decoder.decode(e.msg)  # XXX WAT?
             return data
 
-        logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url)
-        res = await self.bind_threepid(
-            client_secret, sid, mxid, id_server, id_access_token, use_v2=False
-        )
-        return res
-
     async def try_unbind_threepid(self, mxid: str, threepid: dict) -> bool:
         """Attempt to remove a 3PID from an identity server, or if one is not provided, all
         identity servers we're aware the binding is present on
@@ -300,8 +281,8 @@ class IdentityHandler:
                 "id_server must be a valid hostname with optional port and path components",
             )
 
-        url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
-        url_bytes = b"/_matrix/identity/api/v1/3pid/unbind"
+        url = "https://%s/_matrix/identity/v2/3pid/unbind" % (id_server,)
+        url_bytes = b"/_matrix/identity/v2/3pid/unbind"
 
         content = {
             "mxid": mxid,
@@ -434,48 +415,6 @@ class IdentityHandler:
 
         return session_id
 
-    async def requestEmailToken(
-        self,
-        id_server: str,
-        email: str,
-        client_secret: str,
-        send_attempt: int,
-        next_link: Optional[str] = None,
-    ) -> JsonDict:
-        """
-        Request an external server send an email on our behalf for the purposes of threepid
-        validation.
-
-        Args:
-            id_server: The identity server to proxy to
-            email: The email to send the message to
-            client_secret: The unique client_secret sends by the user
-            send_attempt: Which attempt this is
-            next_link: A link to redirect the user to once they submit the token
-
-        Returns:
-            The json response body from the server
-        """
-        params = {
-            "email": email,
-            "client_secret": client_secret,
-            "send_attempt": send_attempt,
-        }
-        if next_link:
-            params["next_link"] = next_link
-
-        try:
-            data = await self.http_client.post_json_get_json(
-                id_server + "/_matrix/identity/api/v1/validate/email/requestToken",
-                params,
-            )
-            return data
-        except HttpResponseException as e:
-            logger.info("Proxied requestToken failed: %r", e)
-            raise e.to_synapse_error()
-        except RequestTimedOutError:
-            raise SynapseError(500, "Timed out contacting identity server")
-
     async def requestMsisdnToken(
         self,
         id_server: str,
@@ -549,18 +488,7 @@ class IdentityHandler:
         validation_session = None
 
         # Try to validate as email
-        if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
-            # Remote emails will only be used if a valid identity server is provided.
-            assert (
-                self.hs.config.registration.account_threepid_delegate_email is not None
-            )
-
-            # Ask our delegated email identity server
-            validation_session = await self.threepid_from_creds(
-                self.hs.config.registration.account_threepid_delegate_email,
-                threepid_creds,
-            )
-        elif self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+        if self.hs.config.email.can_verify_email:
             # Get a validated session matching these details
             validation_session = await self.store.get_threepid_validation_session(
                 "email", client_secret, sid=sid, validated=True
@@ -610,11 +538,7 @@ class IdentityHandler:
             raise SynapseError(400, "Error contacting the identity server")
 
     async def lookup_3pid(
-        self,
-        id_server: str,
-        medium: str,
-        address: str,
-        id_access_token: Optional[str] = None,
+        self, id_server: str, medium: str, address: str, id_access_token: str
     ) -> Optional[str]:
         """Looks up a 3pid in the passed identity server.
 
@@ -629,60 +553,15 @@ class IdentityHandler:
         Returns:
             the matrix ID of the 3pid, or None if it is not recognized.
         """
-        if id_access_token is not None:
-            try:
-                results = await self._lookup_3pid_v2(
-                    id_server, id_access_token, medium, address
-                )
-                return results
-
-            except Exception as e:
-                # Catch HttpResponseExcept for a non-200 response code
-                # Check if this identity server does not know about v2 lookups
-                if isinstance(e, HttpResponseException) and e.code == 404:
-                    # This is an old identity server that does not yet support v2 lookups
-                    logger.warning(
-                        "Attempted v2 lookup on v1 identity server %s. Falling "
-                        "back to v1",
-                        id_server,
-                    )
-                else:
-                    logger.warning("Error when looking up hashing details: %s", e)
-                    return None
-
-        return await self._lookup_3pid_v1(id_server, medium, address)
-
-    async def _lookup_3pid_v1(
-        self, id_server: str, medium: str, address: str
-    ) -> Optional[str]:
-        """Looks up a 3pid in the passed identity server using v1 lookup.
-
-        Args:
-            id_server: The server name (including port, if required)
-                of the identity server to use.
-            medium: The type of the third party identifier (e.g. "email").
-            address: The third party identifier (e.g. "foo@example.com").
 
-        Returns:
-            the matrix ID of the 3pid, or None if it is not recognized.
-        """
         try:
-            data = await self.blacklisting_http_client.get_json(
-                "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server),
-                {"medium": medium, "address": address},
+            results = await self._lookup_3pid_v2(
+                id_server, id_access_token, medium, address
             )
-
-            if "mxid" in data:
-                # note: we used to verify the identity server's signature here, but no longer
-                # require or validate it. See the following for context:
-                # https://github.com/matrix-org/synapse/issues/5253#issuecomment-666246950
-                return data["mxid"]
-        except RequestTimedOutError:
-            raise SynapseError(500, "Timed out contacting identity server")
-        except OSError as e:
-            logger.warning("Error from v1 identity server lookup: %s" % (e,))
-
-        return None
+            return results
+        except Exception as e:
+            logger.warning("Error when looking up hashing details: %s", e)
+            return None
 
     async def _lookup_3pid_v2(
         self, id_server: str, id_access_token: str, medium: str, address: str
@@ -811,7 +690,7 @@ class IdentityHandler:
         room_type: Optional[str],
         inviter_display_name: str,
         inviter_avatar_url: str,
-        id_access_token: Optional[str] = None,
+        id_access_token: str,
     ) -> Tuple[str, List[Dict[str, str]], Dict[str, str], str]:
         """
         Asks an identity server for a third party invite.
@@ -832,7 +711,7 @@ class IdentityHandler:
             inviter_display_name: The current display name of the
                 inviter.
             inviter_avatar_url: The URL of the inviter's avatar.
-            id_access_token (str|None): The access token to authenticate to the identity
+            id_access_token (str): The access token to authenticate to the identity
                 server with
 
         Returns:
@@ -864,71 +743,24 @@ class IdentityHandler:
             invite_config["org.matrix.web_client_location"] = self._web_client_location
 
         # Add the identity service access token to the JSON body and use the v2
-        # Identity Service endpoints if id_access_token is present
+        # Identity Service endpoints
         data = None
-        base_url = "%s%s/_matrix/identity" % (id_server_scheme, id_server)
 
-        if id_access_token:
-            key_validity_url = "%s%s/_matrix/identity/v2/pubkey/isvalid" % (
-                id_server_scheme,
-                id_server,
-            )
+        key_validity_url = "%s%s/_matrix/identity/v2/pubkey/isvalid" % (
+            id_server_scheme,
+            id_server,
+        )
 
-            # Attempt a v2 lookup
-            url = base_url + "/v2/store-invite"
-            try:
-                data = await self.blacklisting_http_client.post_json_get_json(
-                    url,
-                    invite_config,
-                    {"Authorization": create_id_access_token_header(id_access_token)},
-                )
-            except RequestTimedOutError:
-                raise SynapseError(500, "Timed out contacting identity server")
-            except HttpResponseException as e:
-                if e.code != 404:
-                    logger.info("Failed to POST %s with JSON: %s", url, e)
-                    raise e
-
-        if data is None:
-            key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
-                id_server_scheme,
-                id_server,
+        url = "%s%s/_matrix/identity/v2/store-invite" % (id_server_scheme, id_server)
+        try:
+            data = await self.blacklisting_http_client.post_json_get_json(
+                url,
+                invite_config,
+                {"Authorization": create_id_access_token_header(id_access_token)},
             )
-            url = base_url + "/api/v1/store-invite"
-
-            try:
-                data = await self.blacklisting_http_client.post_json_get_json(
-                    url, invite_config
-                )
-            except RequestTimedOutError:
-                raise SynapseError(500, "Timed out contacting identity server")
-            except HttpResponseException as e:
-                logger.warning(
-                    "Error trying to call /store-invite on %s%s: %s",
-                    id_server_scheme,
-                    id_server,
-                    e,
-                )
-
-            if data is None:
-                # Some identity servers may only support application/x-www-form-urlencoded
-                # types. This is especially true with old instances of Sydent, see
-                # https://github.com/matrix-org/sydent/pull/170
-                try:
-                    data = await self.blacklisting_http_client.post_urlencoded_get_json(
-                        url, invite_config
-                    )
-                except HttpResponseException as e:
-                    logger.warning(
-                        "Error calling /store-invite on %s%s with fallback "
-                        "encoding: %s",
-                        id_server_scheme,
-                        id_server,
-                        e,
-                    )
-                    raise e
-
-        # TODO: Check for success
+        except RequestTimedOutError:
+            raise SynapseError(500, "Timed out contacting identity server")
+
         token = data["token"]
         public_keys = data.get("public_keys", [])
         if "public_key" in data:
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 85b472f250..860c82c110 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -143,8 +143,8 @@ class InitialSyncHandler:
             joined_rooms,
             to_key=int(now_token.receipt_key),
         )
-        if self.hs.config.experimental.msc2285_enabled:
-            receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id)
+
+        receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id)
 
         tags_by_room = await self.store.get_tags_for_user(user_id)
 
@@ -309,18 +309,18 @@ class InitialSyncHandler:
         if blocked:
             raise SynapseError(403, "This room has been blocked on this server")
 
-        user_id = requester.user.to_string()
-
         (
             membership,
             member_event_id,
         ) = await self.auth.check_user_in_room_or_world_readable(
             room_id,
-            user_id,
+            requester,
             allow_departed_users=True,
         )
         is_peeking = member_event_id is None
 
+        user_id = requester.user.to_string()
+
         if membership == Membership.JOIN:
             result = await self._room_initial_sync_joined(
                 user_id, room_id, pagin_config, membership, is_peeking
@@ -456,11 +456,8 @@ class InitialSyncHandler:
             )
             if not receipts:
                 return []
-            if self.hs.config.experimental.msc2285_enabled:
-                receipts = ReceiptEventSource.filter_out_private_receipts(
-                    receipts, user_id
-                )
-            return receipts
+
+            return ReceiptEventSource.filter_out_private_receipts(receipts, user_id)
 
         presence, receipts, (messages, token) = await make_deferred_yieldable(
             gather_results(
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index f455158a2c..72157d5a36 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -37,12 +37,14 @@ from synapse.api.errors import (
     AuthError,
     Codes,
     ConsentNotGivenError,
+    LimitExceededError,
     NotFoundError,
     ShadowBanError,
     SynapseError,
+    UnstableSpecAuthError,
     UnsupportedRoomVersionError,
 )
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.api.urls import ConsentURIBuilder
 from synapse.event_auth import validate_event_for_room_version
 from synapse.events import EventBase, relation_from_event
@@ -50,9 +52,11 @@ from synapse.events.builder import EventBuilder
 from synapse.events.snapshot import EventContext
 from synapse.events.validator import EventValidator
 from synapse.handlers.directory import DirectoryHandler
+from synapse.logging import opentracing
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.http.send_event import ReplicationSendEventRestServlet
+from synapse.storage.databases.main.events import PartialStateConflictError
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.state import StateFilter
 from synapse.types import (
@@ -100,7 +104,7 @@ class MessageHandler:
 
     async def get_room_data(
         self,
-        user_id: str,
+        requester: Requester,
         room_id: str,
         event_type: str,
         state_key: str,
@@ -108,7 +112,7 @@ class MessageHandler:
         """Get data from a room.
 
         Args:
-            user_id
+            requester: The user who did the request.
             room_id
             event_type
             state_key
@@ -121,7 +125,7 @@ class MessageHandler:
             membership,
             membership_event_id,
         ) = await self.auth.check_user_in_room_or_world_readable(
-            room_id, user_id, allow_departed_users=True
+            room_id, requester, allow_departed_users=True
         )
 
         if membership == Membership.JOIN:
@@ -147,17 +151,20 @@ class MessageHandler:
                 "Attempted to retrieve data from a room for a user that has never been in it. "
                 "This should not have happened."
             )
-            raise SynapseError(403, "User not in room", errcode=Codes.FORBIDDEN)
+            raise UnstableSpecAuthError(
+                403,
+                "User not in room",
+                errcode=Codes.NOT_JOINED,
+            )
 
         return data
 
     async def get_state_events(
         self,
-        user_id: str,
+        requester: Requester,
         room_id: str,
         state_filter: Optional[StateFilter] = None,
         at_token: Optional[StreamToken] = None,
-        is_guest: bool = False,
     ) -> List[dict]:
         """Retrieve all state events for a given room. If the user is
         joined to the room then return the current state. If the user has
@@ -166,14 +173,13 @@ class MessageHandler:
         visible.
 
         Args:
-            user_id: The user requesting state events.
+            requester: The user requesting state events.
             room_id: The room ID to get all state events from.
             state_filter: The state filter used to fetch state from the database.
             at_token: the stream token of the at which we are requesting
                 the stats. If the user is not allowed to view the state as of that
                 stream token, we raise a 403 SynapseError. If None, returns the current
                 state based on the current_state_events table.
-            is_guest: whether this user is a guest
         Returns:
             A list of dicts representing state events. [{}, {}, {}]
         Raises:
@@ -183,6 +189,7 @@ class MessageHandler:
             members of this room.
         """
         state_filter = state_filter or StateFilter.all()
+        user_id = requester.user.to_string()
 
         if at_token:
             last_event_id = (
@@ -215,7 +222,7 @@ class MessageHandler:
                 membership,
                 membership_event_id,
             ) = await self.auth.check_user_in_room_or_world_readable(
-                room_id, user_id, allow_departed_users=True
+                room_id, requester, allow_departed_users=True
             )
 
             if membership == Membership.JOIN:
@@ -309,30 +316,42 @@ class MessageHandler:
         Returns:
             A dict of user_id to profile info
         """
-        user_id = requester.user.to_string()
         if not requester.app_service:
             # We check AS auth after fetching the room membership, as it
             # requires us to pull out all joined members anyway.
             membership, _ = await self.auth.check_user_in_room_or_world_readable(
-                room_id, user_id, allow_departed_users=True
+                room_id, requester, allow_departed_users=True
             )
             if membership != Membership.JOIN:
-                raise NotImplementedError(
-                    "Getting joined members after leaving is not implemented"
+                raise SynapseError(
+                    code=403,
+                    errcode=Codes.FORBIDDEN,
+                    msg="Getting joined members while not being a current member of the room is forbidden.",
                 )
 
-        users_with_profile = await self.store.get_users_in_room_with_profiles(room_id)
+        users_with_profile = (
+            await self._state_storage_controller.get_users_in_room_with_profiles(
+                room_id
+            )
+        )
 
         # If this is an AS, double check that they are allowed to see the members.
         # This can either be because the AS user is in the room or because there
         # is a user in the room that the AS is "interested in"
-        if requester.app_service and user_id not in users_with_profile:
+        if (
+            requester.app_service
+            and requester.user.to_string() not in users_with_profile
+        ):
             for uid in users_with_profile:
                 if requester.app_service.is_interested_in_user(uid):
                     break
             else:
                 # Loop fell through, AS has no interested users in room
-                raise AuthError(403, "Appservice not in room")
+                raise UnstableSpecAuthError(
+                    403,
+                    "Appservice not in room",
+                    errcode=Codes.NOT_JOINED,
+                )
 
         return {
             user_id: {
@@ -444,7 +463,7 @@ _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY = 7 * 24 * 60 * 60 * 1000
 class EventCreationHandler:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
-        self.auth = hs.get_auth()
+        self.auth_blocking = hs.get_auth_blocking()
         self._event_auth_handler = hs.get_event_auth_handler()
         self.store = hs.get_datastores().main
         self._storage_controllers = hs.get_storage_controllers()
@@ -461,6 +480,7 @@ class EventCreationHandler:
         )
         self._events_shard_config = self.config.worker.events_shard_config
         self._instance_name = hs.get_instance_name()
+        self._notifier = hs.get_notifier()
 
         self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state
 
@@ -605,7 +625,7 @@ class EventCreationHandler:
         Returns:
             Tuple of created event, Context
         """
-        await self.auth.check_auth_blocking(requester=requester)
+        await self.auth_blocking.check_auth_blocking(requester=requester)
 
         if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
             room_version_id = event_dict["content"]["room_version"]
@@ -741,8 +761,10 @@ class EventCreationHandler:
     async def _is_server_notices_room(self, room_id: str) -> bool:
         if self.config.servernotices.server_notices_mxid is None:
             return False
-        user_ids = await self.store.get_users_in_room(room_id)
-        return self.config.servernotices.server_notices_mxid in user_ids
+        is_server_notices_room = await self.store.check_local_user_in_room(
+            user_id=self.config.servernotices.server_notices_mxid, room_id=room_id
+        )
+        return is_server_notices_room
 
     async def assert_accepted_privacy_policy(self, requester: Requester) -> None:
         """Check if a user has accepted the privacy policy
@@ -903,6 +925,9 @@ class EventCreationHandler:
             await self.clock.sleep(random.randint(1, 10))
             raise ShadowBanError()
 
+        if ratelimit:
+            await self.request_ratelimiter.ratelimit(requester, update=False)
+
         # We limit the number of concurrent event sends in a room so that we
         # don't fork the DAG too much. If we don't limit then we can end up in
         # a situation where event persistence can't keep up, causing
@@ -954,14 +979,12 @@ class EventCreationHandler:
                             "Spam-check module returned invalid error value. Expecting [code, dict], got %s",
                             spam_check_result,
                         )
-                        spam_check_result = Codes.FORBIDDEN
 
-                if isinstance(spam_check_result, Codes):
-                    raise SynapseError(
-                        403,
-                        "This message has been rejected as probable spam",
-                        spam_check_result,
-                    )
+                        raise SynapseError(
+                            403,
+                            "This message has been rejected as probable spam",
+                            Codes.FORBIDDEN,
+                        )
 
                 # Backwards compatibility: if the return value is not an error code, it
                 # means the module returned an error message to be included in the
@@ -1102,6 +1125,7 @@ class EventCreationHandler:
             #
             # TODO(faster_joins): figure out how this works, and make sure that the
             #   old state is complete.
+            #   https://github.com/matrix-org/synapse/issues/13003
             metadata = await self.store.get_metadata_for_events(state_event_ids)
 
             state_map_for_event: MutableStateMap[str] = {}
@@ -1130,6 +1154,10 @@ class EventCreationHandler:
             context = await self.state.compute_event_context(
                 event,
                 state_ids_before_event=state_map_for_event,
+                # TODO(faster_joins): check how MSC2716 works and whether we can have
+                #   partial state here
+                #   https://github.com/matrix-org/synapse/issues/13003
+                partial_state=False,
             )
         else:
             context = await self.state.compute_event_context(event)
@@ -1248,6 +1276,8 @@ class EventCreationHandler:
 
         Raises:
             ShadowBanError if the requester has been shadow-banned.
+            SynapseError(503) if attempting to persist a partial state event in
+                a room that has been un-partial stated.
         """
         extra_users = extra_users or []
 
@@ -1273,23 +1303,6 @@ class EventCreationHandler:
                 )
                 return prev_event
 
-        if event.is_state() and (event.type, event.state_key) == (
-            EventTypes.Create,
-            "",
-        ):
-            room_version_id = event.content.get(
-                "room_version", RoomVersions.V1.identifier
-            )
-            maybe_room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id)
-            if not maybe_room_version_obj:
-                raise UnsupportedRoomVersionError(
-                    "Attempt to create a room with unsupported room version %s"
-                    % (room_version_id,)
-                )
-            room_version_obj = maybe_room_version_obj
-        else:
-            room_version_obj = await self.store.get_room_version(event.room_id)
-
         if event.internal_metadata.is_out_of_band_membership():
             # the only sort of out-of-band-membership events we expect to see here are
             # invite rejections and rescinded knocks that we have generated ourselves.
@@ -1297,9 +1310,9 @@ class EventCreationHandler:
             assert event.content["membership"] == Membership.LEAVE
         else:
             try:
-                validate_event_for_room_version(room_version_obj, event)
+                validate_event_for_room_version(event)
                 await self._event_auth_handler.check_auth_rules_from_context(
-                    room_version_obj, event, context
+                    event, context
                 )
             except AuthError as err:
                 logger.warning("Denying new event %r because %s", event, err)
@@ -1315,24 +1328,35 @@ class EventCreationHandler:
 
         # We now persist the event (and update the cache in parallel, since we
         # don't want to block on it).
-        result, _ = await make_deferred_yieldable(
-            gather_results(
-                (
-                    run_in_background(
-                        self._persist_event,
-                        requester=requester,
-                        event=event,
-                        context=context,
-                        ratelimit=ratelimit,
-                        extra_users=extra_users,
+        try:
+            result, _ = await make_deferred_yieldable(
+                gather_results(
+                    (
+                        run_in_background(
+                            self._persist_event,
+                            requester=requester,
+                            event=event,
+                            context=context,
+                            ratelimit=ratelimit,
+                            extra_users=extra_users,
+                        ),
+                        run_in_background(
+                            self.cache_joined_hosts_for_event, event, context
+                        ).addErrback(
+                            log_failure, "cache_joined_hosts_for_event failed"
+                        ),
                     ),
-                    run_in_background(
-                        self.cache_joined_hosts_for_event, event, context
-                    ).addErrback(log_failure, "cache_joined_hosts_for_event failed"),
-                ),
-                consumeErrors=True,
+                    consumeErrors=True,
+                )
+            ).addErrback(unwrapFirstError)
+        except PartialStateConflictError as e:
+            # The event context needs to be recomputed.
+            # Turn the error into a 429, as a hint to the client to try again.
+            logger.info(
+                "Room %s was un-partial stated while persisting client event.",
+                event.room_id,
             )
-        ).addErrback(unwrapFirstError)
+            raise LimitExceededError(msg=e.msg, errcode=e.errcode, retry_after_ms=0)
 
         return result
 
@@ -1347,6 +1371,9 @@ class EventCreationHandler:
         """Actually persists the event. Should only be called by
         `handle_new_client_event`, and see its docstring for documentation of
         the arguments.
+
+        PartialStateConflictError: if attempting to persist a partial state event in
+            a room that has been un-partial stated.
         """
 
         # Skip push notification actions for historical messages
@@ -1355,24 +1382,30 @@ class EventCreationHandler:
         # and `state_groups` because they have `prev_events` that aren't persisted yet
         # (historical messages persisted in reverse-chronological order).
         if not event.internal_metadata.is_historical():
-            await self._bulk_push_rule_evaluator.action_for_event_by_user(
-                event, context
-            )
+            with opentracing.start_active_span("calculate_push_actions"):
+                await self._bulk_push_rule_evaluator.action_for_event_by_user(
+                    event, context
+                )
 
         try:
             # If we're a worker we need to hit out to the master.
             writer_instance = self._events_shard_config.get_instance(event.room_id)
             if writer_instance != self._instance_name:
-                result = await self.send_event(
-                    instance_name=writer_instance,
-                    event_id=event.event_id,
-                    store=self.store,
-                    requester=requester,
-                    event=event,
-                    context=context,
-                    ratelimit=ratelimit,
-                    extra_users=extra_users,
-                )
+                try:
+                    result = await self.send_event(
+                        instance_name=writer_instance,
+                        event_id=event.event_id,
+                        store=self.store,
+                        requester=requester,
+                        event=event,
+                        context=context,
+                        ratelimit=ratelimit,
+                        extra_users=extra_users,
+                    )
+                except SynapseError as e:
+                    if e.code == HTTPStatus.CONFLICT:
+                        raise PartialStateConflictError()
+                    raise
                 stream_id = result["stream_id"]
                 event_id = result["event_id"]
                 if event_id != event.event_id:
@@ -1436,7 +1469,13 @@ class EventCreationHandler:
             if state_entry.state_group in self._external_cache_joined_hosts_updates:
                 return
 
-            joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry)
+            state = await state_entry.get_state(
+                self._storage_controllers.state, StateFilter.all()
+            )
+            with opentracing.start_active_span("get_joined_hosts"):
+                joined_hosts = await self.store.get_joined_hosts(
+                    event.room_id, state, state_entry
+                )
 
             # Note that the expiry times must be larger than the expiry time in
             # _external_cache_joined_hosts_updates.
@@ -1500,6 +1539,10 @@ class EventCreationHandler:
             The persisted event. This may be different than the given event if
             it was de-duplicated (e.g. because we had already persisted an
             event with the same transaction ID.)
+
+        Raises:
+            PartialStateConflictError: if attempting to persist a partial state event in
+                a room that has been un-partial stated.
         """
         extra_users = extra_users or []
 
@@ -1533,6 +1576,16 @@ class EventCreationHandler:
                 requester, is_admin_redaction=is_admin_redaction
             )
 
+        if event.type == EventTypes.Member and event.membership == Membership.JOIN:
+            (
+                current_membership,
+                _,
+            ) = await self.store.get_local_current_membership_for_user_in_room(
+                event.state_key, event.room_id
+            )
+            if current_membership != Membership.JOIN:
+                self._notifier.notify_user_joined_room(event.event_id, event.room_id)
+
         await self._maybe_kick_guest_users(event, context)
 
         if event.type == EventTypes.CanonicalAlias:
@@ -1832,13 +1885,8 @@ class EventCreationHandler:
 
         # For each room we need to find a joined member we can use to send
         # the dummy event with.
-        latest_event_ids = await self.store.get_prev_events_for_room(room_id)
-        members = await self.state.get_current_users_in_room(
-            room_id, latest_event_ids=latest_event_ids
-        )
+        members = await self.store.get_local_users_in_room(room_id)
         for user_id in members:
-            if not self.hs.is_mine_id(user_id):
-                continue
             requester = create_requester(user_id, authenticated_entity=self.server_name)
             try:
                 event, context = await self.create_event(
@@ -1849,7 +1897,6 @@ class EventCreationHandler:
                         "room_id": room_id,
                         "sender": user_id,
                     },
-                    prev_event_ids=latest_event_ids,
                 )
 
                 event.internal_metadata.proactively_send = False
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index 9de61d554f..d7a8226900 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar, U
 from urllib.parse import urlencode, urlparse
 
 import attr
-import pymacaroons
 from authlib.common.security import generate_token
 from authlib.jose import JsonWebToken, jwt
 from authlib.oauth2.auth import ClientAuth
@@ -44,7 +43,7 @@ from synapse.logging.context import make_deferred_yieldable
 from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
 from synapse.util import Clock, json_decoder
 from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
-from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
+from synapse.util.macaroons import MacaroonGenerator, OidcSessionData
 from synapse.util.templates import _localpart_from_email_filter
 
 if TYPE_CHECKING:
@@ -105,9 +104,10 @@ class OidcHandler:
         # we should not have been instantiated if there is no configured provider.
         assert provider_confs
 
-        self._token_generator = OidcSessionTokenGenerator(hs)
+        self._macaroon_generator = hs.get_macaroon_generator()
         self._providers: Dict[str, "OidcProvider"] = {
-            p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
+            p.idp_id: OidcProvider(hs, self._macaroon_generator, p)
+            for p in provider_confs
         }
 
     async def load_metadata(self) -> None:
@@ -216,7 +216,7 @@ class OidcHandler:
 
         # Deserialize the session token and verify it.
         try:
-            session_data = self._token_generator.verify_oidc_session_token(
+            session_data = self._macaroon_generator.verify_oidc_session_token(
                 session, state
             )
         except (MacaroonInitException, MacaroonDeserializationException, KeyError) as e:
@@ -271,12 +271,12 @@ class OidcProvider:
     def __init__(
         self,
         hs: "HomeServer",
-        token_generator: "OidcSessionTokenGenerator",
+        macaroon_generator: MacaroonGenerator,
         provider: OidcProviderConfig,
     ):
         self._store = hs.get_datastores().main
 
-        self._token_generator = token_generator
+        self._macaroon_generaton = macaroon_generator
 
         self._config = provider
         self._callback_url: str = hs.config.oidc.oidc_callback_url
@@ -761,7 +761,7 @@ class OidcProvider:
         if not client_redirect_url:
             client_redirect_url = b""
 
-        cookie = self._token_generator.generate_oidc_session_token(
+        cookie = self._macaroon_generaton.generate_oidc_session_token(
             state=state,
             session_data=OidcSessionData(
                 idp_id=self.idp_id,
@@ -1112,121 +1112,6 @@ class JwtClientSecret:
         return self._cached_secret
 
 
-class OidcSessionTokenGenerator:
-    """Methods for generating and checking OIDC Session cookies."""
-
-    def __init__(self, hs: "HomeServer"):
-        self._clock = hs.get_clock()
-        self._server_name = hs.hostname
-        self._macaroon_secret_key = hs.config.key.macaroon_secret_key
-
-    def generate_oidc_session_token(
-        self,
-        state: str,
-        session_data: "OidcSessionData",
-        duration_in_ms: int = (60 * 60 * 1000),
-    ) -> str:
-        """Generates a signed token storing data about an OIDC session.
-
-        When Synapse initiates an authorization flow, it creates a random state
-        and a random nonce. Those parameters are given to the provider and
-        should be verified when the client comes back from the provider.
-        It is also used to store the client_redirect_url, which is used to
-        complete the SSO login flow.
-
-        Args:
-            state: The ``state`` parameter passed to the OIDC provider.
-            session_data: data to include in the session token.
-            duration_in_ms: An optional duration for the token in milliseconds.
-                Defaults to an hour.
-
-        Returns:
-            A signed macaroon token with the session information.
-        """
-        macaroon = pymacaroons.Macaroon(
-            location=self._server_name,
-            identifier="key",
-            key=self._macaroon_secret_key,
-        )
-        macaroon.add_first_party_caveat("gen = 1")
-        macaroon.add_first_party_caveat("type = session")
-        macaroon.add_first_party_caveat("state = %s" % (state,))
-        macaroon.add_first_party_caveat("idp_id = %s" % (session_data.idp_id,))
-        macaroon.add_first_party_caveat("nonce = %s" % (session_data.nonce,))
-        macaroon.add_first_party_caveat(
-            "client_redirect_url = %s" % (session_data.client_redirect_url,)
-        )
-        macaroon.add_first_party_caveat(
-            "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
-        )
-        now = self._clock.time_msec()
-        expiry = now + duration_in_ms
-        macaroon.add_first_party_caveat("time < %d" % (expiry,))
-
-        return macaroon.serialize()
-
-    def verify_oidc_session_token(
-        self, session: bytes, state: str
-    ) -> "OidcSessionData":
-        """Verifies and extract an OIDC session token.
-
-        This verifies that a given session token was issued by this homeserver
-        and extract the nonce and client_redirect_url caveats.
-
-        Args:
-            session: The session token to verify
-            state: The state the OIDC provider gave back
-
-        Returns:
-            The data extracted from the session cookie
-
-        Raises:
-            KeyError if an expected caveat is missing from the macaroon.
-        """
-        macaroon = pymacaroons.Macaroon.deserialize(session)
-
-        v = pymacaroons.Verifier()
-        v.satisfy_exact("gen = 1")
-        v.satisfy_exact("type = session")
-        v.satisfy_exact("state = %s" % (state,))
-        v.satisfy_general(lambda c: c.startswith("nonce = "))
-        v.satisfy_general(lambda c: c.startswith("idp_id = "))
-        v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
-        v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
-        satisfy_expiry(v, self._clock.time_msec)
-
-        v.verify(macaroon, self._macaroon_secret_key)
-
-        # Extract the session data from the token.
-        nonce = get_value_from_macaroon(macaroon, "nonce")
-        idp_id = get_value_from_macaroon(macaroon, "idp_id")
-        client_redirect_url = get_value_from_macaroon(macaroon, "client_redirect_url")
-        ui_auth_session_id = get_value_from_macaroon(macaroon, "ui_auth_session_id")
-        return OidcSessionData(
-            nonce=nonce,
-            idp_id=idp_id,
-            client_redirect_url=client_redirect_url,
-            ui_auth_session_id=ui_auth_session_id,
-        )
-
-
-@attr.s(frozen=True, slots=True, auto_attribs=True)
-class OidcSessionData:
-    """The attributes which are stored in a OIDC session cookie"""
-
-    # the Identity Provider being used
-    idp_id: str
-
-    # The `nonce` parameter passed to the OIDC provider.
-    nonce: str
-
-    # The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
-    client_redirect_url: str
-
-    # The session ID of the ongoing UI Auth ("" if this is a login)
-    ui_auth_session_id: str
-
-
 class UserAttributeDict(TypedDict):
     localpart: Optional[str]
     confirm_localpart: bool
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 6262a35822..a0c39778ab 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -24,6 +24,7 @@ from synapse.api.errors import SynapseError
 from synapse.api.filtering import Filter
 from synapse.events.utils import SerializeEventConfig
 from synapse.handlers.room import ShutdownRoomResponse
+from synapse.logging.opentracing import trace
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.state import StateFilter
 from synapse.streams.config import PaginationConfig
@@ -158,11 +159,9 @@ class PaginationHandler:
         self._retention_allowed_lifetime_max = (
             hs.config.retention.retention_allowed_lifetime_max
         )
+        self._is_master = hs.config.worker.worker_app is None
 
-        if (
-            hs.config.worker.run_background_tasks
-            and hs.config.retention.retention_enabled
-        ):
+        if hs.config.retention.retention_enabled and self._is_master:
             # Run the purge jobs described in the configuration file.
             for job in hs.config.retention.retention_purge_jobs:
                 logger.info("Setting up purge job with config: %s", job)
@@ -416,6 +415,7 @@ class PaginationHandler:
 
             await self._storage_controllers.purge_events.purge_room(room_id)
 
+    @trace
     async def get_messages(
         self,
         requester: Requester,
@@ -462,7 +462,7 @@ class PaginationHandler:
                 membership,
                 member_event_id,
             ) = await self.auth.check_user_in_room_or_world_readable(
-                room_id, user_id, allow_departed_users=True
+                room_id, requester, allow_departed_users=True
             )
 
             if pagin_config.direction == "b":
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 895ea63ed3..4e575ffbaa 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -34,7 +34,6 @@ from typing import (
     Callable,
     Collection,
     Dict,
-    FrozenSet,
     Generator,
     Iterable,
     List,
@@ -42,7 +41,6 @@ from typing import (
     Set,
     Tuple,
     Type,
-    Union,
 )
 
 from prometheus_client import Counter
@@ -68,7 +66,6 @@ from synapse.storage.databases.main import DataStore
 from synapse.streams import EventSource
 from synapse.types import JsonDict, StreamKeyType, UserID, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
-from synapse.util.caches.descriptors import _CacheContext, cached
 from synapse.util.metrics import Measure
 from synapse.util.wheel_timer import WheelTimer
 
@@ -1656,15 +1653,18 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
                 # doesn't return. C.f. #5503.
                 return [], max_token
 
-            # Figure out which other users this user should receive updates for
-            users_interested_in = await self._get_interested_in(user, explicit_room_id)
+            # Figure out which other users this user should explicitly receive
+            # updates for
+            additional_users_interested_in = (
+                await self.get_presence_router().get_interested_users(user.to_string())
+            )
 
             # We have a set of users that we're interested in the presence of. We want to
             # cross-reference that with the users that have actually changed their presence.
 
             # Check whether this user should see all user updates
 
-            if users_interested_in == PresenceRouter.ALL_USERS:
+            if additional_users_interested_in == PresenceRouter.ALL_USERS:
                 # Provide presence state for all users
                 presence_updates = await self._filter_all_presence_updates_for_user(
                     user_id, include_offline, from_key
@@ -1673,34 +1673,47 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
                 return presence_updates, max_token
 
             # Make mypy happy. users_interested_in should now be a set
-            assert not isinstance(users_interested_in, str)
+            assert not isinstance(additional_users_interested_in, str)
+
+            # We always care about our own presence.
+            additional_users_interested_in.add(user_id)
+
+            if explicit_room_id:
+                user_ids = await self.store.get_users_in_room(explicit_room_id)
+                additional_users_interested_in.update(user_ids)
 
             # The set of users that we're interested in and that have had a presence update.
             # We'll actually pull the presence updates for these users at the end.
-            interested_and_updated_users: Union[Set[str], FrozenSet[str]] = set()
+            interested_and_updated_users: Collection[str]
 
             if from_key is not None:
                 # First get all users that have had a presence update
                 updated_users = stream_change_cache.get_all_entities_changed(from_key)
 
                 # Cross-reference users we're interested in with those that have had updates.
-                # Use a slightly-optimised method for processing smaller sets of updates.
-                if updated_users is not None and len(updated_users) < 500:
-                    # For small deltas, it's quicker to get all changes and then
-                    # cross-reference with the users we're interested in
+                if updated_users is not None:
+                    # If we have the full list of changes for presence we can
+                    # simply check which ones share a room with the user.
                     get_updates_counter.labels("stream").inc()
-                    for other_user_id in updated_users:
-                        if other_user_id in users_interested_in:
-                            # mypy thinks this variable could be a FrozenSet as it's possibly set
-                            # to one in the `get_entities_changed` call below, and `add()` is not
-                            # method on a FrozenSet. That doesn't affect us here though, as
-                            # `interested_and_updated_users` is clearly a set() above.
-                            interested_and_updated_users.add(other_user_id)  # type: ignore
+
+                    sharing_users = await self.store.do_users_share_a_room(
+                        user_id, updated_users
+                    )
+
+                    interested_and_updated_users = (
+                        sharing_users.union(additional_users_interested_in)
+                    ).intersection(updated_users)
+
                 else:
                     # Too many possible updates. Find all users we can see and check
                     # if any of them have changed.
                     get_updates_counter.labels("full").inc()
 
+                    users_interested_in = (
+                        await self.store.get_users_who_share_room_with_user(user_id)
+                    )
+                    users_interested_in.update(additional_users_interested_in)
+
                     interested_and_updated_users = (
                         stream_change_cache.get_entities_changed(
                             users_interested_in, from_key
@@ -1709,7 +1722,10 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
             else:
                 # No from_key has been specified. Return the presence for all users
                 # this user is interested in
-                interested_and_updated_users = users_interested_in
+                interested_and_updated_users = (
+                    await self.store.get_users_who_share_room_with_user(user_id)
+                )
+                interested_and_updated_users.update(additional_users_interested_in)
 
             # Retrieve the current presence state for each user
             users_to_state = await self.get_presence_handler().current_state_for_users(
@@ -1804,62 +1820,6 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
     def get_current_key(self) -> int:
         return self.store.get_current_presence_token()
 
-    @cached(num_args=2, cache_context=True)
-    async def _get_interested_in(
-        self,
-        user: UserID,
-        explicit_room_id: Optional[str] = None,
-        cache_context: Optional[_CacheContext] = None,
-    ) -> Union[Set[str], str]:
-        """Returns the set of users that the given user should see presence
-        updates for.
-
-        Args:
-            user: The user to retrieve presence updates for.
-            explicit_room_id: The users that are in the room will be returned.
-
-        Returns:
-            A set of user IDs to return presence updates for, or "ALL" to return all
-            known updates.
-        """
-        user_id = user.to_string()
-        users_interested_in = set()
-        users_interested_in.add(user_id)  # So that we receive our own presence
-
-        # cache_context isn't likely to ever be None due to the @cached decorator,
-        # but we can't have a non-optional argument after the optional argument
-        # explicit_room_id either. Assert cache_context is not None so we can use it
-        # without mypy complaining.
-        assert cache_context
-
-        # Check with the presence router whether we should poll additional users for
-        # their presence information
-        additional_users = await self.get_presence_router().get_interested_users(
-            user.to_string()
-        )
-        if additional_users == PresenceRouter.ALL_USERS:
-            # If the module requested that this user see the presence updates of *all*
-            # users, then simply return that instead of calculating what rooms this
-            # user shares
-            return PresenceRouter.ALL_USERS
-
-        # Add the additional users from the router
-        users_interested_in.update(additional_users)
-
-        # Find the users who share a room with this user
-        users_who_share_room = await self.store.get_users_who_share_room_with_user(
-            user_id, on_invalidate=cache_context.invalidate
-        )
-        users_interested_in.update(users_who_share_room)
-
-        if explicit_room_id:
-            user_ids = await self.store.get_users_in_room(
-                explicit_room_id, on_invalidate=cache_context.invalidate
-            )
-            users_interested_in.update(user_ids)
-
-        return users_interested_in
-
 
 def handle_timeouts(
     user_states: List[UserPresenceState],
@@ -2091,8 +2051,7 @@ async def get_interested_remotes(
     )
 
     for room_id, states in room_ids_to_states.items():
-        user_ids = await store.get_users_in_room(room_id)
-        hosts = {get_domain_from_id(user_id) for user_id in user_ids}
+        hosts = await store.get_current_hosts_in_room(room_id)
         for host in hosts:
             hosts_and_states.setdefault(host, set()).update(states)
 
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 6eed3826a7..d8ff5289b5 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -67,19 +67,14 @@ class ProfileHandler:
         target_user = UserID.from_string(user_id)
 
         if self.hs.is_mine(target_user):
-            try:
-                displayname = await self.store.get_profile_displayname(
-                    target_user.localpart
-                )
-                avatar_url = await self.store.get_profile_avatar_url(
-                    target_user.localpart
-                )
-            except StoreError as e:
-                if e.code == 404:
-                    raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
-                raise
+            profileinfo = await self.store.get_profileinfo(target_user.localpart)
+            if profileinfo.display_name is None:
+                raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
 
-            return {"displayname": displayname, "avatar_url": avatar_url}
+            return {
+                "displayname": profileinfo.display_name,
+                "avatar_url": profileinfo.avatar_url,
+            }
         else:
             try:
                 result = await self.federation.make_query(
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 43d2882b0a..d2bdb9c8be 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -256,10 +256,9 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
             room_ids, from_key=from_key, to_key=to_key
         )
 
-        if self.config.experimental.msc2285_enabled:
-            events = ReceiptEventSource.filter_out_private_receipts(
-                events, user.to_string()
-            )
+        events = ReceiptEventSource.filter_out_private_receipts(
+            events, user.to_string()
+        )
 
         return events, to_key
 
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 338204287f..20ec22105a 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -29,7 +29,13 @@ from synapse.api.constants import (
     JoinRules,
     LoginType,
 )
-from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
+from synapse.api.errors import (
+    AuthError,
+    Codes,
+    ConsentNotGivenError,
+    InvalidClientTokenError,
+    SynapseError,
+)
 from synapse.appservice import ApplicationService
 from synapse.config.server import is_threepid_reserved
 from synapse.http.servlet import assert_params_in_dict
@@ -91,6 +97,7 @@ class RegistrationHandler:
         self.clock = hs.get_clock()
         self.hs = hs
         self.auth = hs.get_auth()
+        self.auth_blocking = hs.get_auth_blocking()
         self._auth_handler = hs.get_auth_handler()
         self.profile_handler = hs.get_profile_handler()
         self.user_directory_handler = hs.get_user_directory_handler()
@@ -179,10 +186,7 @@ class RegistrationHandler:
                 )
             if guest_access_token:
                 user_data = await self.auth.get_user_by_access_token(guest_access_token)
-                if (
-                    not user_data.is_guest
-                    or UserID.from_string(user_data.user_id).localpart != localpart
-                ):
+                if not user_data.is_guest or user_data.user.localpart != localpart:
                     raise AuthError(
                         403,
                         "Cannot register taken user ID without valid guest "
@@ -276,7 +280,7 @@ class RegistrationHandler:
 
         # do not check_auth_blocking if the call is coming through the Admin API
         if not by_admin:
-            await self.auth.check_auth_blocking(threepid=threepid)
+            await self.auth_blocking.check_auth_blocking(threepid=threepid)
 
         if localpart is not None:
             await self.check_username(localpart, guest_access_token=guest_access_token)
@@ -617,7 +621,7 @@ class RegistrationHandler:
         user_id = user.to_string()
         service = self.store.get_app_service_by_token(as_token)
         if not service:
-            raise AuthError(403, "Invalid application service token.")
+            raise InvalidClientTokenError()
         if not service.is_interested_in_user(user_id):
             raise SynapseError(
                 400,
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 0b63cd2186..28d7093f08 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -19,6 +19,7 @@ import attr
 from synapse.api.constants import RelationTypes
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase, relation_from_event
+from synapse.logging.opentracing import trace
 from synapse.storage.databases.main.relations import _RelatedEvent
 from synapse.types import JsonDict, Requester, StreamToken, UserID
 from synapse.visibility import filter_events_for_client
@@ -73,7 +74,6 @@ class RelationsHandler:
         room_id: str,
         relation_type: Optional[str] = None,
         event_type: Optional[str] = None,
-        aggregation_key: Optional[str] = None,
         limit: int = 5,
         direction: str = "b",
         from_token: Optional[StreamToken] = None,
@@ -89,7 +89,6 @@ class RelationsHandler:
             room_id: The room the event belongs to.
             relation_type: Only fetch events with this relation type, if given.
             event_type: Only fetch events with this event type, if given.
-            aggregation_key: Only fetch events with this aggregation key, if given.
             limit: Only fetch the most recent `limit` events.
             direction: Whether to fetch the most recent first (`"b"`) or the
                 oldest first (`"f"`).
@@ -104,7 +103,7 @@ class RelationsHandler:
 
         # TODO Properly handle a user leaving a room.
         (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable(
-            room_id, user_id, allow_departed_users=True
+            room_id, requester, allow_departed_users=True
         )
 
         # This gets the original event and checks that a) the event exists and
@@ -122,7 +121,6 @@ class RelationsHandler:
             room_id=room_id,
             relation_type=relation_type,
             event_type=event_type,
-            aggregation_key=aggregation_key,
             limit=limit,
             direction=direction,
             from_token=from_token,
@@ -364,6 +362,7 @@ class RelationsHandler:
 
         return results
 
+    @trace
     async def get_bundled_aggregations(
         self, events: Iterable[EventBase], user_id: str
     ) -> Dict[str, BundledAggregations]:
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 520663f172..33e9a87002 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -19,6 +19,7 @@ import math
 import random
 import string
 from collections import OrderedDict
+from http import HTTPStatus
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -60,8 +61,8 @@ from synapse.event_auth import validate_event_for_room_version
 from synapse.events import EventBase
 from synapse.events.utils import copy_and_fixup_power_levels_contents
 from synapse.federation.federation_client import InvalidResponseError
-from synapse.handlers.federation import get_domains_from_state
 from synapse.handlers.relations import BundledAggregations
+from synapse.module_api import NOT_SPAM
 from synapse.rest.admin._base import assert_user_is_admin
 from synapse.storage.state import StateFilter
 from synapse.streams import EventSource
@@ -109,6 +110,7 @@ class RoomCreationHandler:
         self.store = hs.get_datastores().main
         self._storage_controllers = hs.get_storage_controllers()
         self.auth = hs.get_auth()
+        self.auth_blocking = hs.get_auth_blocking()
         self.clock = hs.get_clock()
         self.hs = hs
         self.spam_checker = hs.get_spam_checker()
@@ -226,10 +228,9 @@ class RoomCreationHandler:
                 },
             },
         )
-        old_room_version = await self.store.get_room_version(old_room_id)
-        validate_event_for_room_version(old_room_version, tombstone_event)
+        validate_event_for_room_version(tombstone_event)
         await self._event_auth_handler.check_auth_rules_from_context(
-            old_room_version, tombstone_event, tombstone_context
+            tombstone_event, tombstone_context
         )
 
         # Upgrade the room
@@ -437,9 +438,13 @@ class RoomCreationHandler:
         """
         user_id = requester.user.to_string()
 
-        if not await self.spam_checker.user_may_create_room(user_id):
+        spam_check = await self.spam_checker.user_may_create_room(user_id)
+        if spam_check != NOT_SPAM:
             raise SynapseError(
-                403, "You are not permitted to create rooms", Codes.FORBIDDEN
+                403,
+                "You are not permitted to create rooms",
+                errcode=spam_check[0],
+                additional_fields=spam_check[1],
             )
 
         creation_content: JsonDict = {
@@ -700,14 +705,14 @@ class RoomCreationHandler:
                 was, requested, `room_alias`. Secondly, the stream_id of the
                 last persisted event.
         Raises:
-            SynapseError if the room ID couldn't be stored, or something went
-            horribly wrong.
+            SynapseError if the room ID couldn't be stored, 3pid invitation config
+            validation failed, or something went horribly wrong.
             ResourceLimitError if server is blocked to some resource being
             exceeded
         """
         user_id = requester.user.to_string()
 
-        await self.auth.check_auth_blocking(requester=requester)
+        await self.auth_blocking.check_auth_blocking(requester=requester)
 
         if (
             self._server_notices_mxid is not None
@@ -716,7 +721,7 @@ class RoomCreationHandler:
             # allow the server notices mxid to create rooms
             is_requester_admin = True
         else:
-            is_requester_admin = await self.auth.is_server_admin(requester.user)
+            is_requester_admin = await self.auth.is_server_admin(requester)
 
         # Let the third party rules modify the room creation config if needed, or abort
         # the room creation entirely with an exception.
@@ -727,12 +732,28 @@ class RoomCreationHandler:
         invite_3pid_list = config.get("invite_3pid", [])
         invite_list = config.get("invite", [])
 
-        if not is_requester_admin and not (
-            await self.spam_checker.user_may_create_room(user_id)
-        ):
-            raise SynapseError(
-                403, "You are not permitted to create rooms", Codes.FORBIDDEN
-            )
+        # validate each entry for correctness
+        for invite_3pid in invite_3pid_list:
+            if not all(
+                key in invite_3pid
+                for key in ("medium", "address", "id_server", "id_access_token")
+            ):
+                raise SynapseError(
+                    HTTPStatus.BAD_REQUEST,
+                    "all of `medium`, `address`, `id_server` and `id_access_token` "
+                    "are required when making a 3pid invite",
+                    Codes.MISSING_PARAM,
+                )
+
+        if not is_requester_admin:
+            spam_check = await self.spam_checker.user_may_create_room(user_id)
+            if spam_check != NOT_SPAM:
+                raise SynapseError(
+                    403,
+                    "You are not permitted to create rooms",
+                    errcode=spam_check[0],
+                    additional_fields=spam_check[1],
+                )
 
         if ratelimit:
             await self.request_ratelimiter.ratelimit(requester)
@@ -881,7 +902,11 @@ class RoomCreationHandler:
         # override any attempt to set room versions via the creation_content
         creation_content["room_version"] = room_version.identifier
 
-        last_stream_id = await self._send_events_for_new_room(
+        (
+            last_stream_id,
+            last_sent_event_id,
+            depth,
+        ) = await self._send_events_for_new_room(
             requester,
             room_id,
             preset_config=preset_config,
@@ -897,7 +922,7 @@ class RoomCreationHandler:
         if "name" in config:
             name = config["name"]
             (
-                _,
+                name_event,
                 last_stream_id,
             ) = await self.event_creation_handler.create_and_send_nonmember_event(
                 requester,
@@ -909,12 +934,16 @@ class RoomCreationHandler:
                     "content": {"name": name},
                 },
                 ratelimit=False,
+                prev_event_ids=[last_sent_event_id],
+                depth=depth,
             )
+            last_sent_event_id = name_event.event_id
+            depth += 1
 
         if "topic" in config:
             topic = config["topic"]
             (
-                _,
+                topic_event,
                 last_stream_id,
             ) = await self.event_creation_handler.create_and_send_nonmember_event(
                 requester,
@@ -926,7 +955,11 @@ class RoomCreationHandler:
                     "content": {"topic": topic},
                 },
                 ratelimit=False,
+                prev_event_ids=[last_sent_event_id],
+                depth=depth,
             )
+            last_sent_event_id = topic_event.event_id
+            depth += 1
 
         # we avoid dropping the lock between invites, as otherwise joins can
         # start coming in and making the createRoom slow.
@@ -941,7 +974,7 @@ class RoomCreationHandler:
 
             for invitee in invite_list:
                 (
-                    _,
+                    member_event_id,
                     last_stream_id,
                 ) = await self.room_member_handler.update_membership_locked(
                     requester,
@@ -951,16 +984,23 @@ class RoomCreationHandler:
                     ratelimit=False,
                     content=content,
                     new_room=True,
+                    prev_event_ids=[last_sent_event_id],
+                    depth=depth,
                 )
+                last_sent_event_id = member_event_id
+                depth += 1
 
         for invite_3pid in invite_3pid_list:
             id_server = invite_3pid["id_server"]
-            id_access_token = invite_3pid.get("id_access_token")  # optional
+            id_access_token = invite_3pid["id_access_token"]
             address = invite_3pid["address"]
             medium = invite_3pid["medium"]
             # Note that do_3pid_invite can raise a  ShadowBanError, but this was
             # handled above by emptying invite_3pid_list.
-            last_stream_id = await self.hs.get_room_member_handler().do_3pid_invite(
+            (
+                member_event_id,
+                last_stream_id,
+            ) = await self.hs.get_room_member_handler().do_3pid_invite(
                 room_id,
                 requester.user,
                 medium,
@@ -969,7 +1009,11 @@ class RoomCreationHandler:
                 requester,
                 txn_id=None,
                 id_access_token=id_access_token,
+                prev_event_ids=[last_sent_event_id],
+                depth=depth,
             )
+            last_sent_event_id = member_event_id
+            depth += 1
 
         result = {"room_id": room_id}
 
@@ -997,20 +1041,24 @@ class RoomCreationHandler:
         power_level_content_override: Optional[JsonDict] = None,
         creator_join_profile: Optional[JsonDict] = None,
         ratelimit: bool = True,
-    ) -> int:
+    ) -> Tuple[int, str, int]:
         """Sends the initial events into a new room.
 
         `power_level_content_override` doesn't apply when initial state has
         power level state event content.
 
         Returns:
-            The stream_id of the last event persisted.
+            A tuple containing the stream ID, event ID and depth of the last
+            event sent to the room.
         """
 
         creator_id = creator.user.to_string()
 
         event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
 
+        depth = 1
+        last_sent_event_id: Optional[str] = None
+
         def create(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict:
             e = {"type": etype, "content": content}
 
@@ -1020,19 +1068,30 @@ class RoomCreationHandler:
             return e
 
         async def send(etype: str, content: JsonDict, **kwargs: Any) -> int:
+            nonlocal last_sent_event_id
+            nonlocal depth
+
             event = create(etype, content, **kwargs)
             logger.debug("Sending %s in new room", etype)
             # Allow these events to be sent even if the user is shadow-banned to
             # allow the room creation to complete.
             (
-                _,
+                sent_event,
                 last_stream_id,
             ) = await self.event_creation_handler.create_and_send_nonmember_event(
                 creator,
                 event,
                 ratelimit=False,
                 ignore_shadow_ban=True,
+                # Note: we don't pass state_event_ids here because this triggers
+                # an additional query per event to look them up from the events table.
+                prev_event_ids=[last_sent_event_id] if last_sent_event_id else [],
+                depth=depth,
             )
+
+            last_sent_event_id = sent_event.event_id
+            depth += 1
+
             return last_stream_id
 
         try:
@@ -1046,7 +1105,9 @@ class RoomCreationHandler:
         await send(etype=EventTypes.Create, content=creation_content)
 
         logger.debug("Sending %s in new room", EventTypes.Member)
-        await self.room_member_handler.update_membership(
+        # Room create event must exist at this point
+        assert last_sent_event_id is not None
+        member_event_id, _ = await self.room_member_handler.update_membership(
             creator,
             creator.user,
             room_id,
@@ -1054,7 +1115,10 @@ class RoomCreationHandler:
             ratelimit=ratelimit,
             content=creator_join_profile,
             new_room=True,
+            prev_event_ids=[last_sent_event_id],
+            depth=depth,
         )
+        last_sent_event_id = member_event_id
 
         # We treat the power levels override specially as this needs to be one
         # of the first events that get sent into a room.
@@ -1146,7 +1210,7 @@ class RoomCreationHandler:
                 content={"algorithm": RoomEncryptionAlgorithms.DEFAULT},
             )
 
-        return last_sent_stream_id
+        return last_sent_stream_id, last_sent_event_id, depth
 
     def _generate_room_id(self) -> str:
         """Generates a random room ID.
@@ -1228,13 +1292,16 @@ class RoomContextHandler:
         """
         user = requester.user
         if use_admin_priviledge:
-            await assert_user_is_admin(self.auth, requester.user)
+            await assert_user_is_admin(self.auth, requester)
 
         before_limit = math.floor(limit / 2.0)
         after_limit = limit - before_limit
 
-        users = await self.store.get_users_in_room(room_id)
-        is_peeking = user.to_string() not in users
+        is_user_in_room = await self.store.check_local_user_in_room(
+            user_id=user.to_string(), room_id=room_id
+        )
+        # The user is peeking if they aren't in the room already
+        is_peeking = not is_user_in_room
 
         async def filter_evts(events: List[EventBase]) -> List[EventBase]:
             if use_admin_priviledge:
@@ -1333,6 +1400,7 @@ class TimestampLookupHandler:
         self.store = hs.get_datastores().main
         self.state_handler = hs.get_state_handler()
         self.federation_client = hs.get_federation_client()
+        self.federation_event_handler = hs.get_federation_event_handler()
         self._storage_controllers = hs.get_storage_controllers()
 
     async def get_event_for_timestamp(
@@ -1375,6 +1443,7 @@ class TimestampLookupHandler:
         # the timestamp given and the event we were able to find locally
         is_event_next_to_backward_gap = False
         is_event_next_to_forward_gap = False
+        local_event = None
         if local_event_id:
             local_event = await self.store.get_event(
                 local_event_id, allow_none=False, allow_rejected=False
@@ -1406,17 +1475,16 @@ class TimestampLookupHandler:
                 timestamp,
             )
 
-            # Find other homeservers from the given state in the room
-            curr_state = await self._storage_controllers.state.get_current_state(
-                room_id
+            likely_domains = (
+                await self._storage_controllers.state.get_current_hosts_in_room(room_id)
             )
-            curr_domains = get_domains_from_state(curr_state)
-            likely_domains = [
-                domain for domain, depth in curr_domains if domain != self.server_name
-            ]
 
             # Loop through each homeserver candidate until we get a succesful response
             for domain in likely_domains:
+                # We don't want to ask our own server for information we don't have
+                if domain == self.server_name:
+                    continue
+
                 try:
                     remote_response = await self.federation_client.timestamp_to_event(
                         domain, room_id, timestamp, direction
@@ -1427,41 +1495,74 @@ class TimestampLookupHandler:
                         remote_response,
                     )
 
-                    # TODO: Do we want to persist this as an extremity?
-                    # TODO: I think ideally, we would try to backfill from
-                    # this event and run this whole
-                    # `get_event_for_timestamp` function again to make sure
-                    # they didn't give us an event from their gappy history.
                     remote_event_id = remote_response.event_id
-                    origin_server_ts = remote_response.origin_server_ts
+                    remote_origin_server_ts = remote_response.origin_server_ts
+
+                    # Backfill this event so we can get a pagination token for
+                    # it with `/context` and paginate `/messages` from this
+                    # point.
+                    #
+                    # TODO: The requested timestamp may lie in a part of the
+                    #   event graph that the remote server *also* didn't have,
+                    #   in which case they will have returned another event
+                    #   which may be nowhere near the requested timestamp. In
+                    #   the future, we may need to reconcile that gap and ask
+                    #   other homeservers, and/or extend `/timestamp_to_event`
+                    #   to return events on *both* sides of the timestamp to
+                    #   help reconcile the gap faster.
+                    remote_event = (
+                        await self.federation_event_handler.backfill_event_id(
+                            domain, room_id, remote_event_id
+                        )
+                    )
+
+                    # XXX: When we see that the remote server is not trustworthy,
+                    # maybe we should not ask them first in the future.
+                    if remote_origin_server_ts != remote_event.origin_server_ts:
+                        logger.info(
+                            "get_event_for_timestamp: Remote server (%s) claimed that remote_event_id=%s occured at remote_origin_server_ts=%s but that isn't true (actually occured at %s). Their claims are dubious and we should consider not trusting them.",
+                            domain,
+                            remote_event_id,
+                            remote_origin_server_ts,
+                            remote_event.origin_server_ts,
+                        )
 
                     # Only return the remote event if it's closer than the local event
                     if not local_event or (
-                        abs(origin_server_ts - timestamp)
+                        abs(remote_event.origin_server_ts - timestamp)
                         < abs(local_event.origin_server_ts - timestamp)
                     ):
-                        return remote_event_id, origin_server_ts
+                        logger.info(
+                            "get_event_for_timestamp: returning remote_event_id=%s (%s) since it's closer to timestamp=%s than local_event=%s (%s)",
+                            remote_event_id,
+                            remote_event.origin_server_ts,
+                            timestamp,
+                            local_event.event_id if local_event else None,
+                            local_event.origin_server_ts if local_event else None,
+                        )
+                        return remote_event_id, remote_origin_server_ts
                 except (HttpResponseException, InvalidResponseError) as ex:
                     # Let's not put a high priority on some other homeserver
                     # failing to respond or giving a random response
                     logger.debug(
-                        "Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s",
+                        "get_event_for_timestamp: Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s",
                         domain,
                         type(ex).__name__,
                         ex,
                         ex.args,
                     )
-                except Exception as ex:
+                except Exception:
                     # But we do want to see some exceptions in our code
                     logger.warning(
-                        "Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s",
+                        "get_event_for_timestamp: Failed to fetch /timestamp_to_event from %s because of exception",
                         domain,
-                        type(ex).__name__,
-                        ex,
-                        ex.args,
+                        exc_info=True,
                     )
 
-        if not local_event_id:
+        # To appease mypy, we have to add both of these conditions to check for
+        # `None`. We only expect `local_event` to be `None` when
+        # `local_event_id` is `None` but mypy isn't as smart and assuming as us.
+        if not local_event_id or not local_event:
             raise SynapseError(
                 404,
                 "Unable to find event from %s in direction %s" % (timestamp, direction),
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 183d4ae3c4..bb0bdb8e6f 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -25,6 +25,7 @@ from synapse.api.constants import (
     GuestAccess,
     HistoryVisibility,
     JoinRules,
+    PublicRoomsFilterFields,
 )
 from synapse.api.errors import (
     Codes,
@@ -181,6 +182,7 @@ class RoomListHandler:
                 == HistoryVisibility.WORLD_READABLE,
                 "guest_can_join": room["guest_access"] == "can_join",
                 "join_rule": room["join_rules"],
+                "room_type": room["room_type"],
             }
 
             # Filter out Nones – rather omit the field altogether
@@ -239,7 +241,9 @@ class RoomListHandler:
         response["chunk"] = results
 
         response["total_room_count_estimate"] = await self.store.count_public_rooms(
-            network_tuple, ignore_non_federatable=from_federation
+            network_tuple,
+            ignore_non_federatable=from_federation,
+            search_filter=search_filter,
         )
 
         return response
@@ -508,8 +512,21 @@ class RoomListNextBatch:
 
 
 def _matches_room_entry(room_entry: JsonDict, search_filter: dict) -> bool:
-    if search_filter and search_filter.get("generic_search_term", None):
-        generic_search_term = search_filter["generic_search_term"].upper()
+    """Determines whether the given search filter matches a room entry returned over
+    federation.
+
+    Only used if the remote server does not support MSC2197 remote-filtered search, and
+    hence does not support MSC3827 filtering of `/publicRooms` by room type either.
+
+    In this case, we cannot apply the `room_type` filter since no `room_type` field is
+    returned.
+    """
+    if search_filter and search_filter.get(
+        PublicRoomsFilterFields.GENERIC_SEARCH_TERM, None
+    ):
+        generic_search_term = search_filter[
+            PublicRoomsFilterFields.GENERIC_SEARCH_TERM
+        ].upper()
         if generic_search_term in room_entry.get("name", "").upper():
             return True
         elif generic_search_term in room_entry.get("topic", "").upper():
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index d1199a0644..5d4adf5bfd 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -26,18 +26,14 @@ from synapse.api.constants import (
     GuestAccess,
     Membership,
 )
-from synapse.api.errors import (
-    AuthError,
-    Codes,
-    LimitExceededError,
-    ShadowBanError,
-    SynapseError,
-)
+from synapse.api.errors import AuthError, Codes, ShadowBanError, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.event_auth import get_named_level, get_power_level_event
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
+from synapse.logging import opentracing
+from synapse.module_api import NOT_SPAM
 from synapse.storage.state import StateFilter
 from synapse.types import (
     JsonDict,
@@ -99,26 +95,57 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
             burst_count=hs.config.ratelimiting.rc_joins_local.burst_count,
         )
+        # Tracks joins from local users to rooms this server isn't a member of.
+        # I.e. joins this server makes by requesting /make_join /send_join from
+        # another server.
         self._join_rate_limiter_remote = Ratelimiter(
             store=self.store,
             clock=self.clock,
             rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second,
             burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
         )
+        # TODO: find a better place to keep this Ratelimiter.
+        #   It needs to be
+        #    - written to by event persistence code
+        #    - written to by something which can snoop on replication streams
+        #    - read by the RoomMemberHandler to rate limit joins from local users
+        #    - read by the FederationServer to rate limit make_joins and send_joins from
+        #      other homeservers
+        #   I wonder if a homeserver-wide collection of rate limiters might be cleaner?
+        self._join_rate_per_room_limiter = Ratelimiter(
+            store=self.store,
+            clock=self.clock,
+            rate_hz=hs.config.ratelimiting.rc_joins_per_room.per_second,
+            burst_count=hs.config.ratelimiting.rc_joins_per_room.burst_count,
+        )
 
+        # Ratelimiter for invites, keyed by room (across all issuers, all
+        # recipients).
         self._invites_per_room_limiter = Ratelimiter(
             store=self.store,
             clock=self.clock,
             rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second,
             burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count,
         )
-        self._invites_per_user_limiter = Ratelimiter(
+
+        # Ratelimiter for invites, keyed by recipient (across all rooms, all
+        # issuers).
+        self._invites_per_recipient_limiter = Ratelimiter(
             store=self.store,
             clock=self.clock,
             rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second,
             burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count,
         )
 
+        # Ratelimiter for invites, keyed by issuer (across all rooms, all
+        # recipients).
+        self._invites_per_issuer_limiter = Ratelimiter(
+            store=self.store,
+            clock=self.clock,
+            rate_hz=hs.config.ratelimiting.rc_invites_per_issuer.per_second,
+            burst_count=hs.config.ratelimiting.rc_invites_per_issuer.burst_count,
+        )
+
         self._third_party_invite_limiter = Ratelimiter(
             store=self.store,
             clock=self.clock,
@@ -127,6 +154,18 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         )
 
         self.request_ratelimiter = hs.get_request_ratelimiter()
+        hs.get_notifier().add_new_join_in_room_callback(self._on_user_joined_room)
+
+    def _on_user_joined_room(self, event_id: str, room_id: str) -> None:
+        """Notify the rate limiter that a room join has occurred.
+
+        Use this to inform the RoomMemberHandler about joins that have either
+        - taken place on another homeserver, or
+        - on another worker in this homeserver.
+        Joins actioned by this worker should use the usual `ratelimit` method, which
+        checks the limit and increments the counter in one go.
+        """
+        self._join_rate_per_room_limiter.record_action(requester=None, key=room_id)
 
     @abc.abstractmethod
     async def _remote_join(
@@ -140,7 +179,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         """Try and join a room that this server is not in
 
         Args:
-            requester
+            requester: The user making the request, according to the access token.
             remote_room_hosts: List of servers that can be used to join via.
             room_id: Room that we are trying to join
             user: User who is trying to join
@@ -263,7 +302,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         if room_id:
             await self._invites_per_room_limiter.ratelimit(requester, room_id)
 
-        await self._invites_per_user_limiter.ratelimit(requester, invitee_user_id)
+        await self._invites_per_recipient_limiter.ratelimit(requester, invitee_user_id)
+        if requester is not None:
+            await self._invites_per_issuer_limiter.ratelimit(requester)
 
     async def _local_membership_update(
         self,
@@ -274,6 +315,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         allow_no_prev_events: bool = False,
         prev_event_ids: Optional[List[str]] = None,
         state_event_ids: Optional[List[str]] = None,
+        depth: Optional[int] = None,
         txn_id: Optional[str] = None,
         ratelimit: bool = True,
         content: Optional[dict] = None,
@@ -304,6 +346,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 prev_events are set so we need to set them ourself via this argument.
                 This should normally be left as None, which will cause the auth_event_ids
                 to be calculated based on the room state at the prev_events.
+            depth: Override the depth used to order the event in the DAG.
+                Should normally be set to None, which will cause the depth to be calculated
+                based on the prev_events.
 
             txn_id:
             ratelimit:
@@ -359,6 +404,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             allow_no_prev_events=allow_no_prev_events,
             prev_event_ids=prev_event_ids,
             state_event_ids=state_event_ids,
+            depth=depth,
             require_consent=require_consent,
             outlier=outlier,
             historical=historical,
@@ -379,24 +425,18 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             # Only rate-limit if the user actually joined the room, otherwise we'll end
             # up blocking profile updates.
             if newly_joined and ratelimit:
-                time_now_s = self.clock.time()
-                (
-                    allowed,
-                    time_allowed,
-                ) = await self._join_rate_limiter_local.can_do_action(requester)
-
-                if not allowed:
-                    raise LimitExceededError(
-                        retry_after_ms=int(1000 * (time_allowed - time_now_s))
-                    )
-
-        result_event = await self.event_creation_handler.handle_new_client_event(
-            requester,
-            event,
-            context,
-            extra_users=[target],
-            ratelimit=ratelimit,
-        )
+                await self._join_rate_limiter_local.ratelimit(requester)
+                await self._join_rate_per_room_limiter.ratelimit(
+                    requester, key=room_id, update=False
+                )
+        with opentracing.start_active_span("handle_new_client_event"):
+            result_event = await self.event_creation_handler.handle_new_client_event(
+                requester,
+                event,
+                context,
+                extra_users=[target],
+                ratelimit=ratelimit,
+            )
 
         if event.membership == Membership.LEAVE:
             if prev_member_event_id:
@@ -464,6 +504,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         allow_no_prev_events: bool = False,
         prev_event_ids: Optional[List[str]] = None,
         state_event_ids: Optional[List[str]] = None,
+        depth: Optional[int] = None,
     ) -> Tuple[str, int]:
         """Update a user's membership in a room.
 
@@ -499,6 +540,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 prev_events are set so we need to set them ourself via this argument.
                 This should normally be left as None, which will cause the auth_event_ids
                 to be calculated based on the room state at the prev_events.
+            depth: Override the depth used to order the event in the DAG.
+                Should normally be set to None, which will cause the depth to be calculated
+                based on the prev_events.
 
         Returns:
             A tuple of the new event ID and stream ID.
@@ -521,24 +565,26 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         # by application services), and then by room ID.
         async with self.member_as_limiter.queue(as_id):
             async with self.member_linearizer.queue(key):
-                result = await self.update_membership_locked(
-                    requester,
-                    target,
-                    room_id,
-                    action,
-                    txn_id=txn_id,
-                    remote_room_hosts=remote_room_hosts,
-                    third_party_signed=third_party_signed,
-                    ratelimit=ratelimit,
-                    content=content,
-                    new_room=new_room,
-                    require_consent=require_consent,
-                    outlier=outlier,
-                    historical=historical,
-                    allow_no_prev_events=allow_no_prev_events,
-                    prev_event_ids=prev_event_ids,
-                    state_event_ids=state_event_ids,
-                )
+                with opentracing.start_active_span("update_membership_locked"):
+                    result = await self.update_membership_locked(
+                        requester,
+                        target,
+                        room_id,
+                        action,
+                        txn_id=txn_id,
+                        remote_room_hosts=remote_room_hosts,
+                        third_party_signed=third_party_signed,
+                        ratelimit=ratelimit,
+                        content=content,
+                        new_room=new_room,
+                        require_consent=require_consent,
+                        outlier=outlier,
+                        historical=historical,
+                        allow_no_prev_events=allow_no_prev_events,
+                        prev_event_ids=prev_event_ids,
+                        state_event_ids=state_event_ids,
+                        depth=depth,
+                    )
 
         return result
 
@@ -560,6 +606,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         allow_no_prev_events: bool = False,
         prev_event_ids: Optional[List[str]] = None,
         state_event_ids: Optional[List[str]] = None,
+        depth: Optional[int] = None,
     ) -> Tuple[str, int]:
         """Helper for update_membership.
 
@@ -597,10 +644,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 prev_events are set so we need to set them ourself via this argument.
                 This should normally be left as None, which will cause the auth_event_ids
                 to be calculated based on the room state at the prev_events.
+            depth: Override the depth used to order the event in the DAG.
+                Should normally be set to None, which will cause the depth to be calculated
+                based on the prev_events.
 
         Returns:
             A tuple of the new event ID and stream ID.
         """
+
         content_specified = bool(content)
         if content is None:
             content = {}
@@ -638,7 +689,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 errcode=Codes.BAD_JSON,
             )
 
-        if "avatar_url" in content:
+        if "avatar_url" in content and content.get("avatar_url") is not None:
             if not await self.profile_handler.check_avatar_size_and_mime_type(
                 content["avatar_url"],
             ):
@@ -683,7 +734,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             if target_id == self._server_notices_mxid:
                 raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user")
 
-            block_invite = False
+            block_invite_result = None
 
             if (
                 self._server_notices_mxid is not None
@@ -693,7 +744,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 is_requester_admin = True
 
             else:
-                is_requester_admin = await self.auth.is_server_admin(requester.user)
+                is_requester_admin = await self.auth.is_server_admin(requester)
 
             if not is_requester_admin:
                 if self.config.server.block_non_admin_invites:
@@ -701,16 +752,22 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                         "Blocking invite: user is not admin and non-admin "
                         "invites disabled"
                     )
-                    block_invite = True
+                    block_invite_result = (Codes.FORBIDDEN, {})
 
-                if not await self.spam_checker.user_may_invite(
+                spam_check = await self.spam_checker.user_may_invite(
                     requester.user.to_string(), target_id, room_id
-                ):
+                )
+                if spam_check != NOT_SPAM:
                     logger.info("Blocking invite due to spam checker")
-                    block_invite = True
+                    block_invite_result = spam_check
 
-            if block_invite:
-                raise SynapseError(403, "Invites have been disabled on this server")
+            if block_invite_result is not None:
+                raise SynapseError(
+                    403,
+                    "Invites have been disabled on this server",
+                    errcode=block_invite_result[0],
+                    additional_fields=block_invite_result[1],
+                )
 
         # An empty prev_events list is allowed as long as the auth_event_ids are present
         if prev_event_ids is not None:
@@ -724,6 +781,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 allow_no_prev_events=allow_no_prev_events,
                 prev_event_ids=prev_event_ids,
                 state_event_ids=state_event_ids,
+                depth=depth,
                 content=content,
                 require_consent=require_consent,
                 outlier=outlier,
@@ -732,14 +790,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         latest_event_ids = await self.store.get_prev_events_for_room(room_id)
 
-        current_state_ids = await self.state_handler.get_current_state_ids(
-            room_id, latest_event_ids=latest_event_ids
+        state_before_join = await self.state_handler.compute_state_after_events(
+            room_id, latest_event_ids
         )
 
         # TODO: Refactor into dictionary of explicitly allowed transitions
         # between old and new state, with specific error messages for some
         # transitions and generic otherwise
-        old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
+        old_state_id = state_before_join.get((EventTypes.Member, target.to_string()))
         if old_state_id:
             old_state = await self.store.get_event(old_state_id, allow_none=True)
             old_membership = old_state.content.get("membership") if old_state else None
@@ -790,11 +848,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             if action == "kick":
                 raise AuthError(403, "The target user is not in the room")
 
-        is_host_in_room = await self._is_host_in_room(current_state_ids)
+        is_host_in_room = await self._is_host_in_room(state_before_join)
 
         if effective_membership_state == Membership.JOIN:
             if requester.is_guest:
-                guest_can_join = await self._can_guest_join(current_state_ids)
+                guest_can_join = await self._can_guest_join(state_before_join)
                 if not guest_can_join:
                     # This should be an auth check, but guests are a local concept,
                     # so don't really fit into the general auth process.
@@ -810,7 +868,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 bypass_spam_checker = True
 
             else:
-                bypass_spam_checker = await self.auth.is_server_admin(requester.user)
+                bypass_spam_checker = await self.auth.is_server_admin(requester)
 
             inviter = await self._get_inviter(target.to_string(), room_id)
             if (
@@ -818,30 +876,37 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 # We assume that if the spam checker allowed the user to create
                 # a room then they're allowed to join it.
                 and not new_room
-                and not await self.spam_checker.user_may_join_room(
+            ):
+                spam_check = await self.spam_checker.user_may_join_room(
                     target.to_string(), room_id, is_invited=inviter is not None
                 )
-            ):
-                raise SynapseError(403, "Not allowed to join this room")
+                if spam_check != NOT_SPAM:
+                    raise SynapseError(
+                        403,
+                        "Not allowed to join this room",
+                        errcode=spam_check[0],
+                        additional_fields=spam_check[1],
+                    )
 
             # Check if a remote join should be performed.
             remote_join, remote_room_hosts = await self._should_perform_remote_join(
-                target.to_string(), room_id, remote_room_hosts, content, is_host_in_room
+                target.to_string(),
+                room_id,
+                remote_room_hosts,
+                content,
+                is_host_in_room,
+                state_before_join,
             )
             if remote_join:
                 if ratelimit:
-                    time_now_s = self.clock.time()
-                    (
-                        allowed,
-                        time_allowed,
-                    ) = await self._join_rate_limiter_remote.can_do_action(
+                    await self._join_rate_limiter_remote.ratelimit(
                         requester,
                     )
-
-                    if not allowed:
-                        raise LimitExceededError(
-                            retry_after_ms=int(1000 * (time_allowed - time_now_s))
-                        )
+                    await self._join_rate_per_room_limiter.ratelimit(
+                        requester,
+                        key=room_id,
+                        update=False,
+                    )
 
                 inviter = await self._get_inviter(target.to_string(), room_id)
                 if inviter and not self.hs.is_mine(inviter):
@@ -849,10 +914,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
                 content["membership"] = Membership.JOIN
 
-                profile = self.profile_handler
-                if not content_specified:
-                    content["displayname"] = await profile.get_displayname(target)
-                    content["avatar_url"] = await profile.get_avatar_url(target)
+                try:
+                    profile = self.profile_handler
+                    if not content_specified:
+                        content["displayname"] = await profile.get_displayname(target)
+                        content["avatar_url"] = await profile.get_avatar_url(target)
+                except Exception as e:
+                    logger.info(
+                        "Failed to get profile information while processing remote join for %r: %s",
+                        target,
+                        e,
+                    )
 
                 if requester.is_guest:
                     content["kind"] = "guest"
@@ -929,11 +1001,18 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
                 content["membership"] = Membership.KNOCK
 
-                profile = self.profile_handler
-                if "displayname" not in content:
-                    content["displayname"] = await profile.get_displayname(target)
-                if "avatar_url" not in content:
-                    content["avatar_url"] = await profile.get_avatar_url(target)
+                try:
+                    profile = self.profile_handler
+                    if "displayname" not in content:
+                        content["displayname"] = await profile.get_displayname(target)
+                    if "avatar_url" not in content:
+                        content["avatar_url"] = await profile.get_avatar_url(target)
+                except Exception as e:
+                    logger.info(
+                        "Failed to get profile information while processing remote knock for %r: %s",
+                        target,
+                        e,
+                    )
 
                 return await self.remote_knock(
                     remote_room_hosts, room_id, target, content
@@ -948,6 +1027,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             ratelimit=ratelimit,
             prev_event_ids=latest_event_ids,
             state_event_ids=state_event_ids,
+            depth=depth,
             content=content,
             require_consent=require_consent,
             outlier=outlier,
@@ -960,6 +1040,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         remote_room_hosts: List[str],
         content: JsonDict,
         is_host_in_room: bool,
+        state_before_join: StateMap[str],
     ) -> Tuple[bool, List[str]]:
         """
         Check whether the server should do a remote join (as opposed to a local
@@ -979,6 +1060,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             content: The content to use as the event body of the join. This may
                 be modified.
             is_host_in_room: True if the host is in the room.
+            state_before_join: The state before the join event (i.e. the resolution of
+                the states after its parent events).
 
         Returns:
             A tuple of:
@@ -995,20 +1078,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         # If the host is in the room, but not one of the authorised hosts
         # for restricted join rules, a remote join must be used.
         room_version = await self.store.get_room_version(room_id)
-        current_state_ids = await self._storage_controllers.state.get_current_state_ids(
-            room_id
-        )
 
         # If restricted join rules are not being used, a local join can always
         # be used.
         if not await self.event_auth_handler.has_restricted_join_rules(
-            current_state_ids, room_version
+            state_before_join, room_version
         ):
             return False, []
 
         # If the user is invited to the room or already joined, the join
         # event can always be issued locally.
-        prev_member_event_id = current_state_ids.get((EventTypes.Member, user_id), None)
+        prev_member_event_id = state_before_join.get((EventTypes.Member, user_id), None)
         prev_member_event = None
         if prev_member_event_id:
             prev_member_event = await self.store.get_event(prev_member_event_id)
@@ -1023,10 +1103,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         #
         # If not, generate a new list of remote hosts based on which
         # can issue invites.
-        event_map = await self.store.get_events(current_state_ids.values())
+        event_map = await self.store.get_events(state_before_join.values())
         current_state = {
             state_key: event_map[event_id]
-            for state_key, event_id in current_state_ids.items()
+            for state_key, event_id in state_before_join.items()
         }
         allowed_servers = get_servers_from_users(
             get_users_which_can_issue_invite(current_state)
@@ -1040,7 +1120,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         # Ensure the member should be allowed access via membership in a room.
         await self.event_auth_handler.check_restricted_join_rules(
-            current_state_ids, room_version, user_id, prev_member_event
+            state_before_join, room_version, user_id, prev_member_event
         )
 
         # If this is going to be a local join, additional information must
@@ -1050,7 +1130,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             EventContentFields.AUTHORISING_USER
         ] = await self.event_auth_handler.get_user_which_could_invite(
             room_id,
-            current_state_ids,
+            state_before_join,
         )
 
         return False, []
@@ -1302,8 +1382,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         id_server: str,
         requester: Requester,
         txn_id: Optional[str],
-        id_access_token: Optional[str] = None,
-    ) -> int:
+        id_access_token: str,
+        prev_event_ids: Optional[List[str]] = None,
+        depth: Optional[int] = None,
+    ) -> Tuple[str, int]:
         """Invite a 3PID to a room.
 
         Args:
@@ -1315,16 +1397,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             requester: The user making the request.
             txn_id: The transaction ID this is part of, or None if this is not
                 part of a transaction.
-            id_access_token: The optional identity server access token.
+            id_access_token: Identity server access token.
+            depth: Override the depth used to order the event in the DAG.
+            prev_event_ids: The event IDs to use as the prev events
+                Should normally be set to None, which will cause the depth to be calculated
+                based on the prev_events.
 
         Returns:
-             The new stream ID.
+            Tuple of event ID and stream ordering position
 
         Raises:
             ShadowBanError if the requester has been shadow-banned.
         """
         if self.config.server.block_non_admin_invites:
-            is_requester_admin = await self.auth.is_server_admin(requester.user)
+            is_requester_admin = await self.auth.is_server_admin(requester)
             if not is_requester_admin:
                 raise SynapseError(
                     403, "Invites have been disabled on this server", Codes.FORBIDDEN
@@ -1364,20 +1450,26 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             # We don't check the invite against the spamchecker(s) here (through
             # user_may_invite) because we'll do it further down the line anyway (in
             # update_membership_locked).
-            _, stream_id = await self.update_membership(
+            event_id, stream_id = await self.update_membership(
                 requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
             )
         else:
             # Check if the spamchecker(s) allow this invite to go through.
-            if not await self.spam_checker.user_may_send_3pid_invite(
+            spam_check = await self.spam_checker.user_may_send_3pid_invite(
                 inviter_userid=requester.user.to_string(),
                 medium=medium,
                 address=address,
                 room_id=room_id,
-            ):
-                raise SynapseError(403, "Cannot send threepid invite")
+            )
+            if spam_check != NOT_SPAM:
+                raise SynapseError(
+                    403,
+                    "Cannot send threepid invite",
+                    errcode=spam_check[0],
+                    additional_fields=spam_check[1],
+                )
 
-            stream_id = await self._make_and_store_3pid_invite(
+            event, stream_id = await self._make_and_store_3pid_invite(
                 requester,
                 id_server,
                 medium,
@@ -1386,9 +1478,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 inviter,
                 txn_id=txn_id,
                 id_access_token=id_access_token,
+                prev_event_ids=prev_event_ids,
+                depth=depth,
             )
+            event_id = event.event_id
 
-        return stream_id
+        return event_id, stream_id
 
     async def _make_and_store_3pid_invite(
         self,
@@ -1399,8 +1494,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         room_id: str,
         user: UserID,
         txn_id: Optional[str],
-        id_access_token: Optional[str] = None,
-    ) -> int:
+        id_access_token: str,
+        prev_event_ids: Optional[List[str]] = None,
+        depth: Optional[int] = None,
+    ) -> Tuple[EventBase, int]:
         room_state = await self._storage_controllers.state.get_current_state(
             room_id,
             StateFilter.from_types(
@@ -1493,8 +1590,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             },
             ratelimit=False,
             txn_id=txn_id,
+            prev_event_ids=prev_event_ids,
+            depth=depth,
         )
-        return stream_id
+        return event, stream_id
 
     async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool:
         # Have we just created the room, and is this about to be the very
@@ -1521,8 +1620,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
     async def _is_server_notice_room(self, room_id: str) -> bool:
         if self._server_notices_mxid is None:
             return False
-        user_ids = await self.store.get_users_in_room(room_id)
-        return self._server_notices_mxid in user_ids
+        is_server_notices_room = await self.store.check_local_user_in_room(
+            user_id=self._server_notices_mxid, room_id=room_id
+        )
+        return is_server_notices_room
 
 
 class RoomMemberMasterHandler(RoomMemberHandler):
@@ -1583,14 +1684,18 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         ]
 
         if len(remote_room_hosts) == 0:
-            raise SynapseError(404, "No known servers")
+            raise SynapseError(
+                404,
+                "Can't join remote room because no servers "
+                "that are in the room have been provided.",
+            )
 
         check_complexity = self.hs.config.server.limit_remote_rooms.enabled
         if (
             check_complexity
             and self.hs.config.server.limit_remote_rooms.admins_can_join
         ):
-            check_complexity = not await self.auth.is_server_admin(user)
+            check_complexity = not await self.store.is_server_admin(user)
 
         if check_complexity:
             # Fetch the room complexity
@@ -1820,8 +1925,11 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         ]:
             raise SynapseError(400, "User %s in room %s" % (user_id, room_id))
 
-        if membership:
-            await self.store.forget(user_id, room_id)
+        # In normal case this call is only required if `membership` is not `None`.
+        # But: After the last member had left the room, the background update
+        # `_background_remove_left_rooms` is deleting rows related to this room from
+        # the table `current_state_events` and `get_current_state_events` is `None`.
+        await self.store.forget(user_id, room_id)
 
 
 def get_users_which_can_issue_invite(auth_events: StateMap[EventBase]) -> List[str]:
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index 13098f56ed..732b0310bc 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -28,11 +28,11 @@ from synapse.api.constants import (
     RoomTypes,
 )
 from synapse.api.errors import (
-    AuthError,
     Codes,
     NotFoundError,
     StoreError,
     SynapseError,
+    UnstableSpecAuthError,
     UnsupportedRoomVersionError,
 )
 from synapse.api.ratelimiting import Ratelimiter
@@ -175,10 +175,11 @@ class RoomSummaryHandler:
 
         # First of all, check that the room is accessible.
         if not await self._is_local_room_accessible(requested_room_id, requester):
-            raise AuthError(
+            raise UnstableSpecAuthError(
                 403,
                 "User %s not in room %s, and room previews are disabled"
                 % (requester, requested_room_id),
+                errcode=Codes.NOT_JOINED,
             )
 
         # If this is continuing a previous session, pull the persisted data.
diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py
index a305a66860..e2844799e8 100644
--- a/synapse/handlers/send_email.py
+++ b/synapse/handlers/send_email.py
@@ -23,10 +23,12 @@ from pkg_resources import parse_version
 
 import twisted
 from twisted.internet.defer import Deferred
-from twisted.internet.interfaces import IOpenSSLContextFactory, IReactorTCP
+from twisted.internet.interfaces import IOpenSSLContextFactory
+from twisted.internet.ssl import optionsForClientTLS
 from twisted.mail.smtp import ESMTPSender, ESMTPSenderFactory
 
 from synapse.logging.context import make_deferred_yieldable
+from synapse.types import ISynapseReactor
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -48,7 +50,7 @@ class _NoTLSESMTPSender(ESMTPSender):
 
 
 async def _sendmail(
-    reactor: IReactorTCP,
+    reactor: ISynapseReactor,
     smtphost: str,
     smtpport: int,
     from_addr: str,
@@ -59,6 +61,7 @@ async def _sendmail(
     require_auth: bool = False,
     require_tls: bool = False,
     enable_tls: bool = True,
+    force_tls: bool = False,
 ) -> None:
     """A simple wrapper around ESMTPSenderFactory, to allow substitution in tests
 
@@ -73,8 +76,9 @@ async def _sendmail(
         password: password to give when authenticating
         require_auth: if auth is not offered, fail the request
         require_tls: if TLS is not offered, fail the reqest
-        enable_tls: True to enable TLS. If this is False and require_tls is True,
+        enable_tls: True to enable STARTTLS. If this is False and require_tls is True,
            the request will fail.
+        force_tls: True to enable Implicit TLS.
     """
     msg = BytesIO(msg_bytes)
     d: "Deferred[object]" = Deferred()
@@ -105,13 +109,23 @@ async def _sendmail(
         # set to enable TLS.
         factory = build_sender_factory(hostname=smtphost if enable_tls else None)
 
-    reactor.connectTCP(
-        smtphost,
-        smtpport,
-        factory,
-        timeout=30,
-        bindAddress=None,
-    )
+    if force_tls:
+        reactor.connectSSL(
+            smtphost,
+            smtpport,
+            factory,
+            optionsForClientTLS(smtphost),
+            timeout=30,
+            bindAddress=None,
+        )
+    else:
+        reactor.connectTCP(
+            smtphost,
+            smtpport,
+            factory,
+            timeout=30,
+            bindAddress=None,
+        )
 
     await make_deferred_yieldable(d)
 
@@ -132,6 +146,7 @@ class SendEmailHandler:
         self._smtp_pass = passwd.encode("utf-8") if passwd is not None else None
         self._require_transport_security = hs.config.email.require_transport_security
         self._enable_tls = hs.config.email.enable_smtp_tls
+        self._force_tls = hs.config.email.force_tls
 
         self._sendmail = _sendmail
 
@@ -189,4 +204,5 @@ class SendEmailHandler:
             require_auth=self._smtp_user is not None,
             require_tls=self._require_transport_security,
             enable_tls=self._enable_tls,
+            force_tls=self._force_tls,
         )
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index f45e06eb0e..5c01482acf 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -271,6 +271,9 @@ class StatsHandler:
                 room_state["is_federatable"] = (
                     event_content.get(EventContentFields.FEDERATE, True) is True
                 )
+                room_type = event_content.get(EventContentFields.ROOM_TYPE)
+                if isinstance(room_type, str):
+                    room_state["room_type"] = room_type
             elif typ == EventTypes.JoinRules:
                 room_state["join_rules"] = event_content.get("join_rule")
             elif typ == EventTypes.RoomHistoryVisibility:
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index b4ead79f97..2d95b1fa24 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -13,12 +13,24 @@
 # limitations under the License.
 import itertools
 import logging
-from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Collection,
+    Dict,
+    FrozenSet,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+)
 
 import attr
 from prometheus_client import Counter
 
-from synapse.api.constants import EventTypes, Membership, ReceiptTypes
+from synapse.api.constants import EventTypes, Membership
 from synapse.api.filtering import FilterCollection
 from synapse.api.presence import UserPresenceState
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
@@ -89,7 +101,7 @@ class SyncConfig:
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class TimelineBatch:
     prev_batch: StreamToken
-    events: List[EventBase]
+    events: Sequence[EventBase]
     limited: bool
     # A mapping of event ID to the bundled aggregations for the above events.
     # This is only calculated if limited is true.
@@ -237,9 +249,10 @@ class SyncHandler:
         self.event_sources = hs.get_event_sources()
         self.clock = hs.get_clock()
         self.state = hs.get_state_handler()
-        self.auth = hs.get_auth()
+        self.auth_blocking = hs.get_auth_blocking()
         self._storage_controllers = hs.get_storage_controllers()
         self._state_storage_controller = self._storage_controllers.state
+        self._device_handler = hs.get_device_handler()
 
         # TODO: flush cache entries on subsequent sync request.
         #    Once we get the next /sync request (ie, one with the same access token
@@ -280,7 +293,7 @@ class SyncHandler:
         # not been exceeded (if not part of the group by this point, almost certain
         # auth_blocking will occur)
         user_id = sync_config.user.to_string()
-        await self.auth.check_auth_blocking(requester=requester)
+        await self.auth_blocking.check_auth_blocking(requester=requester)
 
         res = await self.response_cache.wrap(
             sync_config.request_key,
@@ -506,10 +519,17 @@ class SyncHandler:
                 # ensure that we always include current state in the timeline
                 current_state_ids: FrozenSet[str] = frozenset()
                 if any(e.is_state() for e in recents):
+                    # FIXME(faster_joins): We use the partial state here as
+                    # we don't want to block `/sync` on finishing a lazy join.
+                    # Which should be fine once
+                    # https://github.com/matrix-org/synapse/issues/12989 is resolved,
+                    # since we shouldn't reach here anymore?
+                    # Note that we use the current state as a whitelist for filtering
+                    # `recents`, so partial state is only a problem when a membership
+                    # event turns up in `recents` but has not made it into the current
+                    # state.
                     current_state_ids_map = (
-                        await self._state_storage_controller.get_current_state_ids(
-                            room_id
-                        )
+                        await self.store.get_partial_current_state_ids(room_id)
                     )
                     current_state_ids = frozenset(current_state_ids_map.values())
 
@@ -578,7 +598,13 @@ class SyncHandler:
                 if any(e.is_state() for e in loaded_recents):
                     # FIXME(faster_joins): We use the partial state here as
                     # we don't want to block `/sync` on finishing a lazy join.
-                    # Is this the correct way of doing it?
+                    # Which should be fine once
+                    # https://github.com/matrix-org/synapse/issues/12989 is resolved,
+                    # since we shouldn't reach here anymore?
+                    # Note that we use the current state as a whitelist for filtering
+                    # `loaded_recents`, so partial state is only a problem when a
+                    # membership event turns up in `loaded_recents` but has not made it
+                    # into the current state.
                     current_state_ids_map = (
                         await self.store.get_partial_current_state_ids(room_id)
                     )
@@ -626,7 +652,10 @@ class SyncHandler:
         )
 
     async def get_state_after_event(
-        self, event_id: str, state_filter: Optional[StateFilter] = None
+        self,
+        event_id: str,
+        state_filter: Optional[StateFilter] = None,
+        await_full_state: bool = True,
     ) -> StateMap[str]:
         """
         Get the room state after the given event
@@ -634,9 +663,14 @@ class SyncHandler:
         Args:
             event_id: event of interest
             state_filter: The state filter used to fetch state from the database.
+            await_full_state: if `True`, will block if we do not yet have complete state
+                at the event and `state_filter` is not satisfied by partial state.
+                Defaults to `True`.
         """
         state_ids = await self._state_storage_controller.get_state_ids_for_event(
-            event_id, state_filter=state_filter or StateFilter.all()
+            event_id,
+            state_filter=state_filter or StateFilter.all(),
+            await_full_state=await_full_state,
         )
 
         # using get_metadata_for_events here (instead of get_event) sidesteps an issue
@@ -659,6 +693,7 @@ class SyncHandler:
         room_id: str,
         stream_position: StreamToken,
         state_filter: Optional[StateFilter] = None,
+        await_full_state: bool = True,
     ) -> StateMap[str]:
         """Get the room state at a particular stream position
 
@@ -666,6 +701,9 @@ class SyncHandler:
             room_id: room for which to get state
             stream_position: point at which to get state
             state_filter: The state filter used to fetch state from the database.
+            await_full_state: if `True`, will block if we do not yet have complete state
+                at the last event in the room before `stream_position` and
+                `state_filter` is not satisfied by partial state. Defaults to `True`.
         """
         # FIXME: This gets the state at the latest event before the stream ordering,
         # which might not be the same as the "current state" of the room at the time
@@ -677,7 +715,9 @@ class SyncHandler:
 
         if last_event_id:
             state = await self.get_state_after_event(
-                last_event_id, state_filter=state_filter or StateFilter.all()
+                last_event_id,
+                state_filter=state_filter or StateFilter.all(),
+                await_full_state=await_full_state,
             )
 
         else:
@@ -851,16 +891,26 @@ class SyncHandler:
         now_token: StreamToken,
         full_state: bool,
     ) -> MutableStateMap[EventBase]:
-        """Works out the difference in state between the start of the timeline
-        and the previous sync.
+        """Works out the difference in state between the end of the previous sync and
+        the start of the timeline.
 
         Args:
             room_id:
             batch: The timeline batch for the room that will be sent to the user.
             sync_config:
-            since_token: Token of the end of the previous batch. May be None.
+            since_token: Token of the end of the previous batch. May be `None`.
             now_token: Token of the end of the current batch.
             full_state: Whether to force returning the full state.
+                `lazy_load_members` still applies when `full_state` is `True`.
+
+        Returns:
+            The state to return in the sync response for the room.
+
+            Clients will overlay this onto the state at the end of the previous sync to
+            arrive at the state at the start of the timeline.
+
+            Clients will then overlay state events in the timeline to arrive at the
+            state at the end of the timeline, in preparation for the next sync.
         """
         # TODO(mjark) Check if the state events were received by the server
         # after the previous sync, since we need to include those state
@@ -868,8 +918,17 @@ class SyncHandler:
         # TODO(mjark) Check for new redactions in the state events.
 
         with Measure(self.clock, "compute_state_delta"):
+            # The memberships needed for events in the timeline.
+            # Only calculated when `lazy_load_members` is on.
+            members_to_fetch: Optional[Set[str]] = None
+
+            # A dictionary mapping user IDs to the first event in the timeline sent by
+            # them. Only calculated when `lazy_load_members` is on.
+            first_event_by_sender_map: Optional[Dict[str, EventBase]] = None
 
-            members_to_fetch = None
+            # The contribution to the room state from state events in the timeline.
+            # Only contains the last event for any given state key.
+            timeline_state: StateMap[str]
 
             lazy_load_members = sync_config.filter_collection.lazy_load_members()
             include_redundant_members = (
@@ -880,10 +939,23 @@ class SyncHandler:
                 # We only request state for the members needed to display the
                 # timeline:
 
-                members_to_fetch = {
-                    event.sender  # FIXME: we also care about invite targets etc.
-                    for event in batch.events
-                }
+                timeline_state = {}
+
+                members_to_fetch = set()
+                first_event_by_sender_map = {}
+                for event in batch.events:
+                    # Build the map from user IDs to the first timeline event they sent.
+                    if event.sender not in first_event_by_sender_map:
+                        first_event_by_sender_map[event.sender] = event
+
+                    # We need the event's sender, unless their membership was in a
+                    # previous timeline event.
+                    if (EventTypes.Member, event.sender) not in timeline_state:
+                        members_to_fetch.add(event.sender)
+                    # FIXME: we also care about invite targets etc.
+
+                    if event.is_state():
+                        timeline_state[(event.type, event.state_key)] = event.event_id
 
                 if full_state:
                     # always make sure we LL ourselves so we know we're in the room
@@ -893,55 +965,80 @@ class SyncHandler:
                     members_to_fetch.add(sync_config.user.to_string())
 
                 state_filter = StateFilter.from_lazy_load_member_list(members_to_fetch)
+
+                # We are happy to use partial state to compute the `/sync` response.
+                # Since partial state may not include the lazy-loaded memberships we
+                # require, we fix up the state response afterwards with memberships from
+                # auth events.
+                await_full_state = False
             else:
+                timeline_state = {
+                    (event.type, event.state_key): event.event_id
+                    for event in batch.events
+                    if event.is_state()
+                }
+
                 state_filter = StateFilter.all()
+                await_full_state = True
 
-            timeline_state = {
-                (event.type, event.state_key): event.event_id
-                for event in batch.events
-                if event.is_state()
-            }
+            # Now calculate the state to return in the sync response for the room.
+            # This is more or less the change in state between the end of the previous
+            # sync's timeline and the start of the current sync's timeline.
+            # See the docstring above for details.
+            state_ids: StateMap[str]
 
             if full_state:
                 if batch:
-                    current_state_ids = (
+                    state_at_timeline_end = (
                         await self._state_storage_controller.get_state_ids_for_event(
-                            batch.events[-1].event_id, state_filter=state_filter
+                            batch.events[-1].event_id,
+                            state_filter=state_filter,
+                            await_full_state=await_full_state,
                         )
                     )
 
-                    state_ids = (
+                    state_at_timeline_start = (
                         await self._state_storage_controller.get_state_ids_for_event(
-                            batch.events[0].event_id, state_filter=state_filter
+                            batch.events[0].event_id,
+                            state_filter=state_filter,
+                            await_full_state=await_full_state,
                         )
                     )
 
                 else:
-                    current_state_ids = await self.get_state_at(
-                        room_id, stream_position=now_token, state_filter=state_filter
+                    state_at_timeline_end = await self.get_state_at(
+                        room_id,
+                        stream_position=now_token,
+                        state_filter=state_filter,
+                        await_full_state=await_full_state,
                     )
 
-                    state_ids = current_state_ids
+                    state_at_timeline_start = state_at_timeline_end
 
                 state_ids = _calculate_state(
                     timeline_contains=timeline_state,
-                    timeline_start=state_ids,
-                    previous={},
-                    current=current_state_ids,
+                    timeline_start=state_at_timeline_start,
+                    timeline_end=state_at_timeline_end,
+                    previous_timeline_end={},
                     lazy_load_members=lazy_load_members,
                 )
             elif batch.limited:
                 if batch:
                     state_at_timeline_start = (
                         await self._state_storage_controller.get_state_ids_for_event(
-                            batch.events[0].event_id, state_filter=state_filter
+                            batch.events[0].event_id,
+                            state_filter=state_filter,
+                            await_full_state=await_full_state,
                         )
                     )
                 else:
                     # We can get here if the user has ignored the senders of all
                     # the recent events.
                     state_at_timeline_start = await self.get_state_at(
-                        room_id, stream_position=now_token, state_filter=state_filter
+                        room_id,
+                        stream_position=now_token,
+                        state_filter=state_filter,
+                        await_full_state=await_full_state,
                     )
 
                 # for now, we disable LL for gappy syncs - see
@@ -963,28 +1060,35 @@ class SyncHandler:
                 # is indeed the case.
                 assert since_token is not None
                 state_at_previous_sync = await self.get_state_at(
-                    room_id, stream_position=since_token, state_filter=state_filter
+                    room_id,
+                    stream_position=since_token,
+                    state_filter=state_filter,
+                    await_full_state=await_full_state,
                 )
 
                 if batch:
-                    current_state_ids = (
+                    state_at_timeline_end = (
                         await self._state_storage_controller.get_state_ids_for_event(
-                            batch.events[-1].event_id, state_filter=state_filter
+                            batch.events[-1].event_id,
+                            state_filter=state_filter,
+                            await_full_state=await_full_state,
                         )
                     )
                 else:
-                    # Its not clear how we get here, but empirically we do
-                    # (#5407). Logging has been added elsewhere to try and
-                    # figure out where this state comes from.
-                    current_state_ids = await self.get_state_at(
-                        room_id, stream_position=now_token, state_filter=state_filter
+                    # We can get here if the user has ignored the senders of all
+                    # the recent events.
+                    state_at_timeline_end = await self.get_state_at(
+                        room_id,
+                        stream_position=now_token,
+                        state_filter=state_filter,
+                        await_full_state=await_full_state,
                     )
 
                 state_ids = _calculate_state(
                     timeline_contains=timeline_state,
                     timeline_start=state_at_timeline_start,
-                    previous=state_at_previous_sync,
-                    current=current_state_ids,
+                    timeline_end=state_at_timeline_end,
+                    previous_timeline_end=state_at_previous_sync,
                     # we have to include LL members in case LL initial sync missed them
                     lazy_load_members=lazy_load_members,
                 )
@@ -1007,8 +1111,30 @@ class SyncHandler:
                                 (EventTypes.Member, member)
                                 for member in members_to_fetch
                             ),
+                            await_full_state=False,
                         )
 
+            # If we only have partial state for the room, `state_ids` may be missing the
+            # memberships we wanted. We attempt to find some by digging through the auth
+            # events of timeline events.
+            if lazy_load_members and await self.store.is_partial_state_room(room_id):
+                assert members_to_fetch is not None
+                assert first_event_by_sender_map is not None
+
+                additional_state_ids = (
+                    await self._find_missing_partial_state_memberships(
+                        room_id, members_to_fetch, first_event_by_sender_map, state_ids
+                    )
+                )
+                state_ids = {**state_ids, **additional_state_ids}
+
+            # At this point, if `lazy_load_members` is enabled, `state_ids` includes
+            # the memberships of all event senders in the timeline. This is because we
+            # may not have sent the memberships in a previous sync.
+
+            # When `include_redundant_members` is on, we send all the lazy-loaded
+            # memberships of event senders. Otherwise we make an effort to limit the set
+            # of memberships we send to those that we have not already sent to this client.
             if lazy_load_members and not include_redundant_members:
                 cache_key = (sync_config.user.to_string(), sync_config.device_id)
                 cache = self.get_lazy_loaded_members_cache(cache_key)
@@ -1050,18 +1176,107 @@ class SyncHandler:
             if e.type != EventTypes.Aliases  # until MSC2261 or alternative solution
         }
 
+    async def _find_missing_partial_state_memberships(
+        self,
+        room_id: str,
+        members_to_fetch: Collection[str],
+        events_with_membership_auth: Mapping[str, EventBase],
+        found_state_ids: StateMap[str],
+    ) -> StateMap[str]:
+        """Finds missing memberships from a set of auth events and returns them as a
+        state map.
+
+        Args:
+            room_id: The partial state room to find the remaining memberships for.
+            members_to_fetch: The memberships to find.
+            events_with_membership_auth: A mapping from user IDs to events whose auth
+                events are known to contain their membership.
+            found_state_ids: A dict from (type, state_key) -> state_event_id, containing
+                memberships that have been previously found. Entries in
+                `members_to_fetch` that have a membership in `found_state_ids` are
+                ignored.
+
+        Returns:
+            A dict from ("m.room.member", state_key) -> state_event_id, containing the
+            memberships missing from `found_state_ids`.
+
+        Raises:
+            KeyError: if `events_with_membership_auth` does not have an entry for a
+                missing membership. Memberships in `found_state_ids` do not need an
+                entry in `events_with_membership_auth`.
+        """
+        additional_state_ids: MutableStateMap[str] = {}
+
+        # Tracks the missing members for logging purposes.
+        missing_members = set()
+
+        # Identify memberships missing from `found_state_ids` and pick out the auth
+        # events in which to look for them.
+        auth_event_ids: Set[str] = set()
+        for member in members_to_fetch:
+            if (EventTypes.Member, member) in found_state_ids:
+                continue
+
+            missing_members.add(member)
+            event_with_membership_auth = events_with_membership_auth[member]
+            auth_event_ids.update(event_with_membership_auth.auth_event_ids())
+
+        auth_events = await self.store.get_events(auth_event_ids)
+
+        # Run through the missing memberships once more, picking out the memberships
+        # from the pile of auth events we have just fetched.
+        for member in members_to_fetch:
+            if (EventTypes.Member, member) in found_state_ids:
+                continue
+
+            event_with_membership_auth = events_with_membership_auth[member]
+
+            # Dig through the auth events to find the desired membership.
+            for auth_event_id in event_with_membership_auth.auth_event_ids():
+                # We only store events once we have all their auth events,
+                # so the auth event must be in the pile we have just
+                # fetched.
+                auth_event = auth_events[auth_event_id]
+
+                if (
+                    auth_event.type == EventTypes.Member
+                    and auth_event.state_key == member
+                ):
+                    missing_members.remove(member)
+                    additional_state_ids[
+                        (EventTypes.Member, member)
+                    ] = auth_event.event_id
+                    break
+
+        if missing_members:
+            # There really shouldn't be any missing memberships now. Either:
+            #  * we couldn't find an auth event, which shouldn't happen because we do
+            #    not persist events with persisting their auth events first, or
+            #  * the set of auth events did not contain a membership we wanted, which
+            #    means our caller didn't compute the events in `members_to_fetch`
+            #    correctly, or we somehow accepted an event whose auth events were
+            #    dodgy.
+            logger.error(
+                "Failed to find memberships for %s in partial state room "
+                "%s in the auth events of %s.",
+                missing_members,
+                room_id,
+                [
+                    events_with_membership_auth[member].event_id
+                    for member in missing_members
+                ],
+            )
+
+        return additional_state_ids
+
     async def unread_notifs_for_room_id(
         self, room_id: str, sync_config: SyncConfig
     ) -> NotifCounts:
         with Measure(self.clock, "unread_notifs_for_room_id"):
-            last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
-                user_id=sync_config.user.to_string(),
-                room_id=room_id,
-                receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
-            )
 
             return await self.store.get_unread_event_push_actions_by_room_for_user(
-                room_id, sync_config.user.to_string(), last_unread_event_id
+                room_id,
+                sync_config.user.to_string(),
             )
 
     async def generate_sync_result(
@@ -1272,21 +1487,11 @@ class SyncHandler:
                     ):
                         users_that_have_changed.add(changed_user_id)
             else:
-                users_who_share_room = (
-                    await self.store.get_users_who_share_room_with_user(user_id)
-                )
-
-                # Always tell the user about their own devices. We check as the user
-                # ID is almost certainly already included (unless they're not in any
-                # rooms) and taking a copy of the set is relatively expensive.
-                if user_id not in users_who_share_room:
-                    users_who_share_room = set(users_who_share_room)
-                    users_who_share_room.add(user_id)
-
-                tracked_users = users_who_share_room
                 users_that_have_changed = (
-                    await self.store.get_users_whose_devices_changed(
-                        since_token.device_list_key, tracked_users
+                    await self._device_handler.get_device_changes_in_shared_rooms(
+                        user_id,
+                        sync_result_builder.joined_room_ids,
+                        from_token=since_token,
                     )
                 )
 
@@ -1549,15 +1754,13 @@ class SyncHandler:
         ignored_users = await self.store.ignored_users(user_id)
         if since_token:
             room_changes = await self._get_rooms_changed(
-                sync_result_builder, ignored_users, self.rooms_to_exclude
+                sync_result_builder, ignored_users
             )
             tags_by_room = await self.store.get_updated_tags(
                 user_id, since_token.account_data_key
             )
         else:
-            room_changes = await self._get_all_rooms(
-                sync_result_builder, ignored_users, self.rooms_to_exclude
-            )
+            room_changes = await self._get_all_rooms(sync_result_builder, ignored_users)
             tags_by_room = await self.store.get_tags_for_user(user_id)
 
         log_kv({"rooms_changed": len(room_changes.room_entries)})
@@ -1636,13 +1839,14 @@ class SyncHandler:
         self,
         sync_result_builder: "SyncResultBuilder",
         ignored_users: FrozenSet[str],
-        excluded_rooms: List[str],
     ) -> _RoomChanges:
         """Determine the changes in rooms to report to the user.
 
         This function is a first pass at generating the rooms part of the sync response.
         It determines which rooms have changed during the sync period, and categorises
-        them into four buckets: "knock", "invite", "join" and "leave".
+        them into four buckets: "knock", "invite", "join" and "leave". It also excludes
+        from that list any room that appears in the list of rooms to exclude from sync
+        results in the server configuration.
 
         1. Finds all membership changes for the user in the sync period (from
            `since_token` up to `now_token`).
@@ -1668,7 +1872,7 @@ class SyncHandler:
         #       _have_rooms_changed. We could keep the results in memory to avoid a
         #       second query, at the cost of more complicated source code.
         membership_change_events = await self.store.get_membership_changes_for_user(
-            user_id, since_token.room_key, now_token.room_key, excluded_rooms
+            user_id, since_token.room_key, now_token.room_key, self.rooms_to_exclude
         )
 
         mem_change_events_by_room_id: Dict[str, List[EventBase]] = {}
@@ -1709,7 +1913,11 @@ class SyncHandler:
                 continue
 
             if room_id in sync_result_builder.joined_room_ids or has_join:
-                old_state_ids = await self.get_state_at(room_id, since_token)
+                old_state_ids = await self.get_state_at(
+                    room_id,
+                    since_token,
+                    state_filter=StateFilter.from_types([(EventTypes.Member, user_id)]),
+                )
                 old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
                 old_mem_ev = None
                 if old_mem_ev_id:
@@ -1735,7 +1943,13 @@ class SyncHandler:
                     newly_left_rooms.append(room_id)
                 else:
                     if not old_state_ids:
-                        old_state_ids = await self.get_state_at(room_id, since_token)
+                        old_state_ids = await self.get_state_at(
+                            room_id,
+                            since_token,
+                            state_filter=StateFilter.from_types(
+                                [(EventTypes.Member, user_id)]
+                            ),
+                        )
                         old_mem_ev_id = old_state_ids.get(
                             (EventTypes.Member, user_id), None
                         )
@@ -1875,7 +2089,6 @@ class SyncHandler:
         self,
         sync_result_builder: "SyncResultBuilder",
         ignored_users: FrozenSet[str],
-        ignored_rooms: List[str],
     ) -> _RoomChanges:
         """Returns entries for all rooms for the user.
 
@@ -1897,7 +2110,7 @@ class SyncHandler:
         room_list = await self.store.get_rooms_for_local_user_where_membership_is(
             user_id=user_id,
             membership_list=Membership.LIST,
-            excluded_rooms=ignored_rooms,
+            excluded_rooms=self.rooms_to_exclude,
         )
 
         room_entries = []
@@ -2163,7 +2376,9 @@ class SyncHandler:
                 raise Exception("Unrecognized rtype: %r", room_builder.rtype)
 
     async def get_rooms_for_user_at(
-        self, user_id: str, room_key: RoomStreamToken
+        self,
+        user_id: str,
+        room_key: RoomStreamToken,
     ) -> FrozenSet[str]:
         """Get set of joined rooms for a user at the given stream ordering.
 
@@ -2189,7 +2404,12 @@ class SyncHandler:
         # If the membership's stream ordering is after the given stream
         # ordering, we need to go and work out if the user was in the room
         # before.
+        # We also need to check whether the room should be excluded from sync
+        # responses as per the homeserver config.
         for joined_room in joined_rooms:
+            if joined_room.room_id in self.rooms_to_exclude:
+                continue
+
             if not joined_room.event_pos.persisted_after(room_key):
                 joined_room_ids.add(joined_room.room_id)
                 continue
@@ -2201,10 +2421,10 @@ class SyncHandler:
                     joined_room.room_id, joined_room.event_pos.stream
                 )
             )
-            users_in_room = await self.state.get_current_users_in_room(
+            user_ids_in_room = await self.state.get_current_user_ids_in_room(
                 joined_room.room_id, extrems
             )
-            if user_id in users_in_room:
+            if user_id in user_ids_in_room:
                 joined_room_ids.add(joined_room.room_id)
 
         return frozenset(joined_room_ids)
@@ -2224,8 +2444,8 @@ def _action_has_highlight(actions: List[JsonDict]) -> bool:
 def _calculate_state(
     timeline_contains: StateMap[str],
     timeline_start: StateMap[str],
-    previous: StateMap[str],
-    current: StateMap[str],
+    timeline_end: StateMap[str],
+    previous_timeline_end: StateMap[str],
     lazy_load_members: bool,
 ) -> StateMap[str]:
     """Works out what state to include in a sync response.
@@ -2233,45 +2453,50 @@ def _calculate_state(
     Args:
         timeline_contains: state in the timeline
         timeline_start: state at the start of the timeline
-        previous: state at the end of the previous sync (or empty dict
+        timeline_end: state at the end of the timeline
+        previous_timeline_end: state at the end of the previous sync (or empty dict
             if this is an initial sync)
-        current: state at the end of the timeline
         lazy_load_members: whether to return members from timeline_start
             or not.  assumes that timeline_start has already been filtered to
             include only the members the client needs to know about.
     """
-    event_id_to_key = {
-        e: key
-        for key, e in itertools.chain(
+    event_id_to_state_key = {
+        event_id: state_key
+        for state_key, event_id in itertools.chain(
             timeline_contains.items(),
-            previous.items(),
             timeline_start.items(),
-            current.items(),
+            timeline_end.items(),
+            previous_timeline_end.items(),
         )
     }
 
-    c_ids = set(current.values())
-    ts_ids = set(timeline_start.values())
-    p_ids = set(previous.values())
-    tc_ids = set(timeline_contains.values())
+    timeline_end_ids = set(timeline_end.values())
+    timeline_start_ids = set(timeline_start.values())
+    previous_timeline_end_ids = set(previous_timeline_end.values())
+    timeline_contains_ids = set(timeline_contains.values())
 
     # If we are lazyloading room members, we explicitly add the membership events
     # for the senders in the timeline into the state block returned by /sync,
     # as we may not have sent them to the client before.  We find these membership
     # events by filtering them out of timeline_start, which has already been filtered
     # to only include membership events for the senders in the timeline.
-    # In practice, we can do this by removing them from the p_ids list,
-    # which is the list of relevant state we know we have already sent to the client.
+    # In practice, we can do this by removing them from the previous_timeline_end_ids
+    # list, which is the list of relevant state we know we have already sent to the
+    # client.
     # see https://github.com/matrix-org/synapse/pull/2970/files/efcdacad7d1b7f52f879179701c7e0d9b763511f#r204732809
 
     if lazy_load_members:
-        p_ids.difference_update(
+        previous_timeline_end_ids.difference_update(
             e for t, e in timeline_start.items() if t[0] == EventTypes.Member
         )
 
-    state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids
+    state_ids = (
+        (timeline_end_ids | timeline_start_ids)
+        - previous_timeline_end_ids
+        - timeline_contains_ids
+    )
 
-    return {event_id_to_key[e]: e for e in state_ids}
+    return {event_id_to_state_key[e]: e for e in state_ids}
 
 
 @attr.s(slots=True, auto_attribs=True)
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index d104ea07fe..a4cd8b8f0c 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -26,7 +26,7 @@ from synapse.metrics.background_process_metrics import (
 )
 from synapse.replication.tcp.streams import TypingStream
 from synapse.streams import EventSource
-from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id
+from synapse.types import JsonDict, Requester, StreamKeyType, UserID
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.metrics import Measure
 from synapse.util.wheel_timer import WheelTimer
@@ -253,12 +253,11 @@ class TypingWriterHandler(FollowerTypingHandler):
         self, target_user: UserID, requester: Requester, room_id: str, timeout: int
     ) -> None:
         target_user_id = target_user.to_string()
-        auth_user_id = requester.user.to_string()
 
         if not self.is_mine_id(target_user_id):
             raise SynapseError(400, "User is not hosted on this homeserver")
 
-        if target_user_id != auth_user_id:
+        if target_user != requester.user:
             raise AuthError(400, "Cannot set another user's typing state")
 
         if requester.shadow_banned:
@@ -266,7 +265,7 @@ class TypingWriterHandler(FollowerTypingHandler):
             await self.clock.sleep(random.randint(1, 10))
             raise ShadowBanError()
 
-        await self.auth.check_user_in_room(room_id, target_user_id)
+        await self.auth.check_user_in_room(room_id, requester)
 
         logger.debug("%s has started typing in %s", target_user_id, room_id)
 
@@ -289,12 +288,11 @@ class TypingWriterHandler(FollowerTypingHandler):
         self, target_user: UserID, requester: Requester, room_id: str
     ) -> None:
         target_user_id = target_user.to_string()
-        auth_user_id = requester.user.to_string()
 
         if not self.is_mine_id(target_user_id):
             raise SynapseError(400, "User is not hosted on this homeserver")
 
-        if target_user_id != auth_user_id:
+        if target_user != requester.user:
             raise AuthError(400, "Cannot set another user's typing state")
 
         if requester.shadow_banned:
@@ -302,7 +300,7 @@ class TypingWriterHandler(FollowerTypingHandler):
             await self.clock.sleep(random.randint(1, 10))
             raise ShadowBanError()
 
-        await self.auth.check_user_in_room(room_id, target_user_id)
+        await self.auth.check_user_in_room(room_id, requester)
 
         logger.debug("%s has stopped typing in %s", target_user_id, room_id)
 
@@ -364,8 +362,9 @@ class TypingWriterHandler(FollowerTypingHandler):
             )
             return
 
-        users = await self.store.get_users_in_room(room_id)
-        domains = {get_domain_from_id(u) for u in users}
+        domains = await self._storage_controllers.state.get_current_hosts_in_room(
+            room_id
+        )
 
         if self.server_name in domains:
             logger.info("Got typing update from %s: %r", user_id, content)
@@ -489,8 +488,15 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
             handler = self.get_typing_handler()
 
             events = []
-            for room_id in handler._room_serials.keys():
-                if handler._room_serials[room_id] <= from_key:
+
+            # Work on a copy of things here as these may change in the handler while
+            # waiting for the AS `is_interested_in_room` call to complete.
+            # Shallow copy is safe as no nested data is present.
+            latest_room_serial = handler._latest_room_serial
+            room_serials = handler._room_serials.copy()
+
+            for room_id, serial in room_serials.items():
+                if serial <= from_key:
                     continue
 
                 if not await service.is_interested_in_room(room_id, self._main_store):
@@ -498,7 +504,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
 
                 events.append(self._make_event_for(room_id))
 
-            return events, handler._latest_room_serial
+            return events, latest_room_serial
 
     async def get_new_events(
         self,
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 05cebb5d4d..a744d68c64 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -19,7 +19,6 @@ from twisted.web.client import PartialDownloadError
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import Codes, LoginError, SynapseError
-from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.util import json_decoder
 
 if TYPE_CHECKING:
@@ -153,7 +152,7 @@ class _BaseThreepidAuthChecker:
 
         logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))
 
-        # msisdns are currently always ThreepidBehaviour.REMOTE
+        # msisdns are currently always verified via the IS
         if medium == "msisdn":
             if not self.hs.config.registration.account_threepid_delegate_msisdn:
                 raise SynapseError(
@@ -164,18 +163,7 @@ class _BaseThreepidAuthChecker:
                 threepid_creds,
             )
         elif medium == "email":
-            if (
-                self.hs.config.email.threepid_behaviour_email
-                == ThreepidBehaviour.REMOTE
-            ):
-                assert self.hs.config.registration.account_threepid_delegate_email
-                threepid = await identity_handler.threepid_from_creds(
-                    self.hs.config.registration.account_threepid_delegate_email,
-                    threepid_creds,
-                )
-            elif (
-                self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL
-            ):
+            if self.hs.config.email.can_verify_email:
                 threepid = None
                 row = await self.store.get_threepid_validation_session(
                     medium,
@@ -227,10 +215,7 @@ class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChec
         _BaseThreepidAuthChecker.__init__(self, hs)
 
     def is_enabled(self) -> bool:
-        return self.hs.config.email.threepid_behaviour_email in (
-            ThreepidBehaviour.REMOTE,
-            ThreepidBehaviour.LOCAL,
-        )
+        return self.hs.config.email.can_verify_email
 
     async def check_auth(self, authdict: dict, clientip: str) -> Any:
         return await self._check_threepid("email", authdict)
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 776ed43f03..3c35b1d2c7 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -79,6 +79,7 @@ from synapse.types import JsonDict
 from synapse.util import json_decoder
 from synapse.util.async_helpers import AwakenableSleeper, timeout_deferred
 from synapse.util.metrics import Measure
+from synapse.util.stringutils import parse_and_validate_server_name
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -479,6 +480,14 @@ class MatrixFederationHttpClient:
             RequestSendFailed: If there were problems connecting to the
                 remote, due to e.g. DNS failures, connection timeouts etc.
         """
+        # Validate server name and log if it is an invalid destination, this is
+        # partially to help track down code paths where we haven't validated before here
+        try:
+            parse_and_validate_server_name(request.destination)
+        except ValueError:
+            logger.exception(f"Invalid destination: {request.destination}.")
+            raise FederationDeniedError(request.destination)
+
         if timeout:
             _sec_timeout = timeout / 1000
         else:
@@ -731,8 +740,11 @@ class MatrixFederationHttpClient:
         Returns:
             A list of headers to be added as "Authorization:" headers
         """
-        if destination is None and destination_is is None:
-            raise ValueError("destination and destination_is cannot both be None!")
+        if not destination and not destination_is:
+            raise ValueError(
+                "At least one of the arguments destination and destination_is "
+                "must be a nonempty bytestring."
+            )
 
         request: JsonDict = {
             "method": method.decode("ascii"),
diff --git a/synapse/http/server.py b/synapse/http/server.py
index e3dcc3f3dd..6068a94b40 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -33,7 +33,6 @@ from typing import (
     Optional,
     Pattern,
     Tuple,
-    TypeVar,
     Union,
 )
 
@@ -58,11 +57,13 @@ from synapse.api.errors import (
     SynapseError,
     UnrecognizedRequestError,
 )
+from synapse.config.homeserver import HomeServerConfig
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import defer_to_thread, preserve_fn, run_in_background
 from synapse.logging.opentracing import active_span, start_active_span, trace_servlet
 from synapse.util import json_encoder
 from synapse.util.caches import intern_dict
+from synapse.util.cancellation import is_function_cancellable
 from synapse.util.iterutils import chunk_seq
 
 if TYPE_CHECKING:
@@ -93,77 +94,16 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
 HTTP_STATUS_REQUEST_CANCELLED = 499
 
 
-F = TypeVar("F", bound=Callable[..., Any])
-
-
-_cancellable_method_names = frozenset(
-    {
-        # `RestServlet`, `BaseFederationServlet` and `BaseFederationServerServlet`
-        # methods
-        "on_GET",
-        "on_PUT",
-        "on_POST",
-        "on_DELETE",
-        # `_AsyncResource`, `DirectServeHtmlResource` and `DirectServeJsonResource`
-        # methods
-        "_async_render_GET",
-        "_async_render_PUT",
-        "_async_render_POST",
-        "_async_render_DELETE",
-        "_async_render_OPTIONS",
-        # `ReplicationEndpoint` methods
-        "_handle_request",
-    }
-)
-
-
-def cancellable(method: F) -> F:
-    """Marks a servlet method as cancellable.
-
-    Methods with this decorator will be cancelled if the client disconnects before we
-    finish processing the request.
-
-    During cancellation, `Deferred.cancel()` will be invoked on the `Deferred` wrapping
-    the method. The `cancel()` call will propagate down to the `Deferred` that is
-    currently being waited on. That `Deferred` will raise a `CancelledError`, which will
-    propagate up, as per normal exception handling.
-
-    Before applying this decorator to a new endpoint, you MUST recursively check
-    that all `await`s in the function are on `async` functions or `Deferred`s that
-    handle cancellation cleanly, otherwise a variety of bugs may occur, ranging from
-    premature logging context closure, to stuck requests, to database corruption.
-
-    Usage:
-        class SomeServlet(RestServlet):
-            @cancellable
-            async def on_GET(self, request: SynapseRequest) -> ...:
-                ...
-    """
-    if method.__name__ not in _cancellable_method_names and not any(
-        method.__name__.startswith(prefix) for prefix in _cancellable_method_names
-    ):
-        raise ValueError(
-            "@cancellable decorator can only be applied to servlet methods."
-        )
-
-    method.cancellable = True  # type: ignore[attr-defined]
-    return method
-
-
-def is_method_cancellable(method: Callable[..., Any]) -> bool:
-    """Checks whether a servlet method has the `@cancellable` flag."""
-    return getattr(method, "cancellable", False)
-
-
-def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
+def return_json_error(
+    f: failure.Failure, request: SynapseRequest, config: Optional[HomeServerConfig]
+) -> None:
     """Sends a JSON error response to clients."""
 
     if f.check(SynapseError):
         # mypy doesn't understand that f.check asserts the type.
         exc: SynapseError = f.value  # type: ignore
         error_code = exc.code
-        error_dict = exc.error_dict()
-
+        error_dict = exc.error_dict(config)
         logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg)
     elif f.check(CancelledError):
         error_code = HTTP_STATUS_REQUEST_CANCELLED
@@ -387,7 +327,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
 
         method_handler = getattr(self, "_async_render_%s" % (request_method,), None)
         if method_handler:
-            request.is_render_cancellable = is_method_cancellable(method_handler)
+            request.is_render_cancellable = is_function_cancellable(method_handler)
 
             raw_callback_return = method_handler(request)
 
@@ -450,7 +390,7 @@ class DirectServeJsonResource(_AsyncResource):
         request: SynapseRequest,
     ) -> None:
         """Implements _AsyncResource._send_error_response"""
-        return_json_error(f, request)
+        return_json_error(f, request, None)
 
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -549,7 +489,7 @@ class JsonResource(DirectServeJsonResource):
     async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]:
         callback, servlet_classname, group_dict = self._get_handler_for_request(request)
 
-        request.is_render_cancellable = is_method_cancellable(callback)
+        request.is_render_cancellable = is_function_cancellable(callback)
 
         # Make sure we have an appropriate name for this handler in prometheus
         # (rather than the default of JsonResource).
@@ -575,6 +515,14 @@ class JsonResource(DirectServeJsonResource):
 
         return callback_return
 
+    def _send_error_response(
+        self,
+        f: failure.Failure,
+        request: SynapseRequest,
+    ) -> None:
+        """Implements _AsyncResource._send_error_response"""
+        return_json_error(f, request, self.hs.config)
+
 
 class DirectServeHtmlResource(_AsyncResource):
     """A resource that will call `self._async_on_<METHOD>` on new requests,
@@ -928,6 +876,17 @@ def set_cors_headers(request: Request) -> None:
     )
 
 
+def set_corp_headers(request: Request) -> None:
+    """Set the CORP headers so that javascript running in a web browsers can
+    embed the resource returned from this request when their client requires
+    the `Cross-Origin-Embedder-Policy: require-corp` header.
+
+    Args:
+        request: The http request to add the CORP header to.
+    """
+    request.setHeader(b"Cross-Origin-Resource-Policy", b"cross-origin")
+
+
 def respond_with_html(request: Request, code: int, html: str) -> None:
     """
     Wraps `respond_with_html_bytes` by first encoding HTML from a str to UTF-8 bytes.
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 4ff840ca0e..26aaabfb34 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -23,9 +23,12 @@ from typing import (
     Optional,
     Sequence,
     Tuple,
+    Type,
+    TypeVar,
     overload,
 )
 
+from pydantic import BaseModel, ValidationError
 from typing_extensions import Literal
 
 from twisted.web.server import Request
@@ -694,6 +697,28 @@ def parse_json_object_from_request(
     return content
 
 
+Model = TypeVar("Model", bound=BaseModel)
+
+
+def parse_and_validate_json_object_from_request(
+    request: Request, model_type: Type[Model]
+) -> Model:
+    """Parse a JSON object from the body of a twisted HTTP request, then deserialise and
+    validate using the given pydantic model.
+
+    Raises:
+        SynapseError if the request body couldn't be decoded as JSON or
+            if it wasn't a JSON object.
+    """
+    content = parse_json_object_from_request(request, allow_empty_body=False)
+    try:
+        instance = model_type.parse_obj(content)
+    except ValidationError as e:
+        raise SynapseError(HTTPStatus.BAD_REQUEST, str(e), errcode=Codes.BAD_JSON)
+
+    return instance
+
+
 def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None:
     absent = []
     for k in required:
diff --git a/synapse/http/site.py b/synapse/http/site.py
index eeec74b78a..1155f3f610 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -226,7 +226,7 @@ class SynapseRequest(Request):
 
             # If this is a request where the target user doesn't match the user who
             # authenticated (e.g. and admin is puppetting a user) then we return both.
-            if self._requester.user.to_string() != authenticated_entity:
+            if requester != authenticated_entity:
                 return requester, authenticated_entity
 
             return requester, None
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 903ec40c86..482316a1ff 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -84,14 +84,13 @@ the function becomes the operation name for the span.
        return something_usual_and_useful
 
 
-Operation names can be explicitly set for a function by passing the
-operation name to ``trace``
+Operation names can be explicitly set for a function by using ``trace_with_opname``:
 
 .. code-block:: python
 
-   from synapse.logging.opentracing import trace
+   from synapse.logging.opentracing import trace_with_opname
 
-   @trace(opname="a_better_operation_name")
+   @trace_with_opname("a_better_operation_name")
    def interesting_badly_named_function(*args, **kwargs):
        # Does all kinds of cool and expected things
        return something_usual_and_useful
@@ -164,6 +163,7 @@ Gotchas
   with an active span?
 """
 import contextlib
+import enum
 import inspect
 import logging
 import re
@@ -173,6 +173,7 @@ from typing import (
     Any,
     Callable,
     Collection,
+    ContextManager,
     Dict,
     Generator,
     Iterable,
@@ -182,6 +183,8 @@ from typing import (
     Type,
     TypeVar,
     Union,
+    cast,
+    overload,
 )
 
 import attr
@@ -268,7 +271,7 @@ try:
 
         _reporter: Reporter = attr.Factory(Reporter)
 
-        def set_process(self, *args, **kwargs):
+        def set_process(self, *args: Any, **kwargs: Any) -> None:
             return self._reporter.set_process(*args, **kwargs)
 
         def report_span(self, span: "opentracing.Span") -> None:
@@ -307,6 +310,19 @@ class SynapseTags:
     # The name of the external cache
     CACHE_NAME = "cache.name"
 
+    # Used to tag function arguments
+    #
+    # Tag a named arg. The name of the argument should be appended to this prefix.
+    FUNC_ARG_PREFIX = "ARG."
+    # Tag extra variadic number of positional arguments (`def foo(first, second, *extras)`)
+    FUNC_ARGS = "args"
+    # Tag keyword args
+    FUNC_KWARGS = "kwargs"
+
+    # Some intermediate result that's interesting to the function. The label for
+    # the result should be appended to this prefix.
+    RESULT_PREFIX = "RESULT."
+
 
 class SynapseBaggage:
     FORCE_TRACING = "synapse-force-tracing"
@@ -319,11 +335,16 @@ _homeserver_whitelist: Optional[Pattern[str]] = None
 
 # Util methods
 
-Sentinel = object()
+
+class _Sentinel(enum.Enum):
+    # defining a sentinel in this way allows mypy to correctly handle the
+    # type of a dictionary lookup.
+    sentinel = object()
 
 
 P = ParamSpec("P")
 R = TypeVar("R")
+T = TypeVar("T")
 
 
 def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]:
@@ -339,22 +360,43 @@ def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]:
     return _only_if_tracing_inner
 
 
-def ensure_active_span(message, ret=None):
+@overload
+def ensure_active_span(
+    message: str,
+) -> Callable[[Callable[P, R]], Callable[P, Optional[R]]]:
+    ...
+
+
+@overload
+def ensure_active_span(
+    message: str, ret: T
+) -> Callable[[Callable[P, R]], Callable[P, Union[T, R]]]:
+    ...
+
+
+def ensure_active_span(
+    message: str, ret: Optional[T] = None
+) -> Callable[[Callable[P, R]], Callable[P, Union[Optional[T], R]]]:
     """Executes the operation only if opentracing is enabled and there is an active span.
     If there is no active span it logs message at the error level.
 
     Args:
-        message (str): Message which fills in "There was no active span when trying to %s"
+        message: Message which fills in "There was no active span when trying to %s"
             in the error log if there is no active span and opentracing is enabled.
-        ret (object): return value if opentracing is None or there is no active span.
+        ret: return value if opentracing is None or there is no active span.
 
-    Returns (object): The result of the func or ret if opentracing is disabled or there
+    Returns:
+        The result of the func, falling back to ret if opentracing is disabled or there
         was no active span.
     """
 
-    def ensure_active_span_inner_1(func):
+    def ensure_active_span_inner_1(
+        func: Callable[P, R]
+    ) -> Callable[P, Union[Optional[T], R]]:
         @wraps(func)
-        def ensure_active_span_inner_2(*args, **kwargs):
+        def ensure_active_span_inner_2(
+            *args: P.args, **kwargs: P.kwargs
+        ) -> Union[Optional[T], R]:
             if not opentracing:
                 return ret
 
@@ -402,7 +444,7 @@ def init_tracer(hs: "HomeServer") -> None:
     config = JaegerConfig(
         config=hs.config.tracing.jaeger_config,
         service_name=f"{hs.config.server.server_name} {hs.get_instance_name()}",
-        scope_manager=LogContextScopeManager(hs.config),
+        scope_manager=LogContextScopeManager(),
         metrics_factory=PrometheusMetricsFactory(),
     )
 
@@ -451,16 +493,16 @@ def whitelisted_homeserver(destination: str) -> bool:
 
 # Could use kwargs but I want these to be explicit
 def start_active_span(
-    operation_name,
-    child_of=None,
-    references=None,
-    tags=None,
-    start_time=None,
-    ignore_active_span=False,
-    finish_on_close=True,
+    operation_name: str,
+    child_of: Optional[Union["opentracing.Span", "opentracing.SpanContext"]] = None,
+    references: Optional[List["opentracing.Reference"]] = None,
+    tags: Optional[Dict[str, str]] = None,
+    start_time: Optional[float] = None,
+    ignore_active_span: bool = False,
+    finish_on_close: bool = True,
     *,
-    tracer=None,
-):
+    tracer: Optional["opentracing.Tracer"] = None,
+) -> "opentracing.Scope":
     """Starts an active opentracing span.
 
     Records the start time for the span, and sets it as the "active span" in the
@@ -493,12 +535,12 @@ def start_active_span(
 def start_active_span_follows_from(
     operation_name: str,
     contexts: Collection,
-    child_of=None,
+    child_of: Optional[Union["opentracing.Span", "opentracing.SpanContext"]] = None,
     start_time: Optional[float] = None,
     *,
-    inherit_force_tracing=False,
-    tracer=None,
-):
+    inherit_force_tracing: bool = False,
+    tracer: Optional["opentracing.Tracer"] = None,
+) -> "opentracing.Scope":
     """Starts an active opentracing span, with additional references to previous spans
 
     Args:
@@ -540,7 +582,7 @@ def start_active_span_from_edu(
     edu_content: Dict[str, Any],
     operation_name: str,
     references: Optional[List["opentracing.Reference"]] = None,
-    tags: Optional[Dict] = None,
+    tags: Optional[Dict[str, str]] = None,
     start_time: Optional[float] = None,
     ignore_active_span: bool = False,
     finish_on_close: bool = True,
@@ -617,23 +659,27 @@ def set_operation_name(operation_name: str) -> None:
 
 
 @only_if_tracing
-def force_tracing(span=Sentinel) -> None:
+def force_tracing(
+    span: Union["opentracing.Span", _Sentinel] = _Sentinel.sentinel
+) -> None:
     """Force sampling for the active/given span and its children.
 
     Args:
         span: span to force tracing for. By default, the active span.
     """
-    if span is Sentinel:
-        span = opentracing.tracer.active_span
-    if span is None:
+    if isinstance(span, _Sentinel):
+        span_to_trace = opentracing.tracer.active_span
+    else:
+        span_to_trace = span
+    if span_to_trace is None:
         logger.error("No active span in force_tracing")
         return
 
-    span.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
+    span_to_trace.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
 
     # also set a bit of baggage, so that we have a way of figuring out if
     # it is enabled later
-    span.set_baggage_item(SynapseBaggage.FORCE_TRACING, "1")
+    span_to_trace.set_baggage_item(SynapseBaggage.FORCE_TRACING, "1")
 
 
 def is_context_forced_tracing(
@@ -709,7 +755,9 @@ def inject_response_headers(response_headers: Headers) -> None:
         response_headers.addRawHeader("Synapse-Trace-Id", f"{trace_id:x}")
 
 
-@ensure_active_span("get the active span context as a dict", ret={})
+@ensure_active_span(
+    "get the active span context as a dict", ret=cast(Dict[str, str], {})
+)
 def get_active_span_text_map(destination: Optional[str] = None) -> Dict[str, str]:
     """
     Gets a span context as a dict. This can be used instead of manually
@@ -789,92 +837,155 @@ def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanConte
 # Tracing decorators
 
 
-def trace(func=None, opname=None):
+def _custom_sync_async_decorator(
+    func: Callable[P, R],
+    wrapping_logic: Callable[[Callable[P, R], Any, Any], ContextManager[None]],
+) -> Callable[P, R]:
     """
-    Decorator to trace a function.
-    Sets the operation name to that of the function's or that given
-    as operation_name. See the module's doc string for usage
-    examples.
+    Decorates a function that is sync or async (coroutines), or that returns a Twisted
+    `Deferred`. The custom business logic of the decorator goes in `wrapping_logic`.
+
+    Example usage:
+    ```py
+    # Decorator to time the function and log it out
+    def duration(func: Callable[P, R]) -> Callable[P, R]:
+        @contextlib.contextmanager
+        def _wrapping_logic(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> Generator[None, None, None]:
+            start_ts = time.time()
+            try:
+                yield
+            finally:
+                end_ts = time.time()
+                duration = end_ts - start_ts
+                logger.info("%s took %s seconds", func.__name__, duration)
+        return _custom_sync_async_decorator(func, _wrapping_logic)
+    ```
+
+    Args:
+        func: The function to be decorated
+        wrapping_logic: The business logic of your custom decorator.
+            This should be a ContextManager so you are able to run your logic
+            before/after the function as desired.
     """
 
-    def decorator(func):
-        if opentracing is None:
-            return func  # type: ignore[unreachable]
+    if inspect.iscoroutinefunction(func):
 
-        _opname = opname if opname else func.__name__
+        @wraps(func)
+        async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+            with wrapping_logic(func, *args, **kwargs):
+                return await func(*args, **kwargs)  # type: ignore[misc]
 
-        if inspect.iscoroutinefunction(func):
+    else:
+        # The other case here handles both sync functions and those
+        # decorated with inlineDeferred.
+        @wraps(func)
+        def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+            scope = wrapping_logic(func, *args, **kwargs)
+            scope.__enter__()
 
-            @wraps(func)
-            async def _trace_inner(*args, **kwargs):
-                with start_active_span(_opname):
-                    return await func(*args, **kwargs)
+            try:
+                result = func(*args, **kwargs)
+                if isinstance(result, defer.Deferred):
 
-        else:
-            # The other case here handles both sync functions and those
-            # decorated with inlineDeferred.
-            @wraps(func)
-            def _trace_inner(*args, **kwargs):
-                scope = start_active_span(_opname)
-                scope.__enter__()
-
-                try:
-                    result = func(*args, **kwargs)
-                    if isinstance(result, defer.Deferred):
-
-                        def call_back(result):
-                            scope.__exit__(None, None, None)
-                            return result
-
-                        def err_back(result):
-                            scope.__exit__(None, None, None)
-                            return result
-
-                        result.addCallbacks(call_back, err_back)
-
-                    else:
-                        if inspect.isawaitable(result):
-                            logger.error(
-                                "@trace may not have wrapped %s correctly! "
-                                "The function is not async but returned a %s.",
-                                func.__qualname__,
-                                type(result).__name__,
-                            )
+                    def call_back(result: R) -> R:
+                        scope.__exit__(None, None, None)
+                        return result
 
+                    def err_back(result: R) -> R:
                         scope.__exit__(None, None, None)
+                        return result
 
-                    return result
+                    result.addCallbacks(call_back, err_back)
 
-                except Exception as e:
-                    scope.__exit__(type(e), None, e.__traceback__)
-                    raise
+                else:
+                    if inspect.isawaitable(result):
+                        logger.error(
+                            "@trace may not have wrapped %s correctly! "
+                            "The function is not async but returned a %s.",
+                            func.__qualname__,
+                            type(result).__name__,
+                        )
 
-        return _trace_inner
+                    scope.__exit__(None, None, None)
 
-    if func:
-        return decorator(func)
-    else:
-        return decorator
+                return result
+
+            except Exception as e:
+                scope.__exit__(type(e), None, e.__traceback__)
+                raise
+
+    return _wrapper  # type: ignore[return-value]
+
+
+def trace_with_opname(
+    opname: str,
+    *,
+    tracer: Optional["opentracing.Tracer"] = None,
+) -> Callable[[Callable[P, R]], Callable[P, R]]:
+    """
+    Decorator to trace a function with a custom opname.
+    See the module's doc string for usage examples.
+    """
+
+    # type-ignore: mypy bug, see https://github.com/python/mypy/issues/12909
+    @contextlib.contextmanager  # type: ignore[arg-type]
+    def _wrapping_logic(
+        func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
+    ) -> Generator[None, None, None]:
+        with start_active_span(opname, tracer=tracer):
+            yield
+
+    def _decorator(func: Callable[P, R]) -> Callable[P, R]:
+        if not opentracing:
+            return func
+
+        return _custom_sync_async_decorator(func, _wrapping_logic)
+
+    return _decorator
+
+
+def trace(func: Callable[P, R]) -> Callable[P, R]:
+    """
+    Decorator to trace a function.
+    Sets the operation name to that of the function's name.
+    See the module's doc string for usage examples.
+    """
+
+    return trace_with_opname(func.__name__)(func)
 
 
 def tag_args(func: Callable[P, R]) -> Callable[P, R]:
     """
-    Tags all of the args to the active span.
+    Decorator to tag all of the args to the active span.
+
+    Args:
+        func: `func` is assumed to be a method taking a `self` parameter, or a
+            `classmethod` taking a `cls` parameter. In either case, a tag is not
+            created for this parameter.
     """
 
     if not opentracing:
         return func
 
-    @wraps(func)
-    def _tag_args_inner(*args: P.args, **kwargs: P.kwargs) -> R:
+    # type-ignore: mypy bug, see https://github.com/python/mypy/issues/12909
+    @contextlib.contextmanager  # type: ignore[arg-type]
+    def _wrapping_logic(
+        func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
+    ) -> Generator[None, None, None]:
         argspec = inspect.getfullargspec(func)
-        for i, arg in enumerate(argspec.args[1:]):
-            set_tag("ARG_" + arg, args[i])  # type: ignore[index]
-        set_tag("args", args[len(argspec.args) :])  # type: ignore[index]
-        set_tag("kwargs", kwargs)
-        return func(*args, **kwargs)
-
-    return _tag_args_inner
+        # We use `[1:]` to skip the `self` object reference and `start=1` to
+        # make the index line up with `argspec.args`.
+        #
+        # FIXME: We could update this to handle any type of function by ignoring the
+        #   first argument only if it's named `self` or `cls`. This isn't fool-proof
+        #   but handles the idiomatic cases.
+        for i, arg in enumerate(args[1:], start=1):  # type: ignore[index]
+            set_tag(SynapseTags.FUNC_ARG_PREFIX + argspec.args[i], str(arg))
+        set_tag(SynapseTags.FUNC_ARGS, str(args[len(argspec.args) :]))  # type: ignore[index]
+        set_tag(SynapseTags.FUNC_KWARGS, str(kwargs))
+        yield
+
+    return _custom_sync_async_decorator(func, _wrapping_logic)
 
 
 @contextlib.contextmanager
diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py
index a26a1a58e7..10877bdfc5 100644
--- a/synapse/logging/scopecontextmanager.py
+++ b/synapse/logging/scopecontextmanager.py
@@ -16,11 +16,15 @@ import logging
 from types import TracebackType
 from typing import Optional, Type
 
-from opentracing import Scope, ScopeManager
+from opentracing import Scope, ScopeManager, Span
 
 import twisted
 
-from synapse.logging.context import current_context, nested_logging_context
+from synapse.logging.context import (
+    LoggingContext,
+    current_context,
+    nested_logging_context,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -35,11 +39,11 @@ class LogContextScopeManager(ScopeManager):
     but currently that doesn't work due to https://twistedmatrix.com/trac/ticket/10301.
     """
 
-    def __init__(self, config):
+    def __init__(self) -> None:
         pass
 
     @property
-    def active(self):
+    def active(self) -> Optional[Scope]:
         """
         Returns the currently active Scope which can be used to access the
         currently active Scope.span.
@@ -48,19 +52,18 @@ class LogContextScopeManager(ScopeManager):
         Tracer.start_active_span() time.
 
         Return:
-            (Scope) : the Scope that is active, or None if not
-            available.
+            The Scope that is active, or None if not available.
         """
         ctx = current_context()
         return ctx.scope
 
-    def activate(self, span, finish_on_close):
+    def activate(self, span: Span, finish_on_close: bool) -> Scope:
         """
         Makes a Span active.
         Args
-            span (Span): the span that should become active.
-            finish_on_close (Boolean): whether Span should be automatically
-                finished when Scope.close() is called.
+            span: the span that should become active.
+            finish_on_close: whether Span should be automatically finished when
+                Scope.close() is called.
 
         Returns:
             Scope to control the end of the active period for
@@ -112,8 +115,8 @@ class _LogContextScope(Scope):
     def __init__(
         self,
         manager: LogContextScopeManager,
-        span,
-        logcontext,
+        span: Span,
+        logcontext: LoggingContext,
         enter_logcontext: bool,
         finish_on_close: bool,
     ):
@@ -121,13 +124,13 @@ class _LogContextScope(Scope):
         Args:
             manager:
                 the manager that is responsible for this scope.
-            span (Span):
+            span:
                 the opentracing span which this scope represents the local
                 lifetime for.
-            logcontext (LogContext):
-                the logcontext to which this scope is attached.
+            logcontext:
+                the log context to which this scope is attached.
             enter_logcontext:
-                if True the logcontext will be exited when the scope is finished
+                if True the log context will be exited when the scope is finished
             finish_on_close:
                 if True finish the span when the scope is closed
         """
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 496fce2ecc..c3d3daf877 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -46,12 +46,12 @@ from twisted.python.threadpool import ThreadPool
 
 # This module is imported for its side effects; flake8 needn't warn that it's unused.
 import synapse.metrics._reactor_metrics  # noqa: F401
-from synapse.metrics._exposition import (
+from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager
+from synapse.metrics._legacy_exposition import (
     MetricsResource,
     generate_latest,
     start_http_server,
 )
-from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager
 from synapse.metrics._types import Collector
 from synapse.util import SYNAPSE_VERSION
 
diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_legacy_exposition.py
index 353d0a63b6..ff640a49af 100644
--- a/synapse/metrics/_exposition.py
+++ b/synapse/metrics/_legacy_exposition.py
@@ -80,7 +80,27 @@ def sample_line(line: Sample, name: str) -> str:
     return "{}{} {}{}\n".format(name, labelstr, floatToGoString(line.value), timestamp)
 
 
+# Mapping from new metric names to legacy metric names.
+# We translate these back to their old names when exposing them through our
+# legacy vendored exporter.
+# Only this legacy exposition module applies these name changes.
+LEGACY_METRIC_NAMES = {
+    "synapse_util_caches_cache_hits": "synapse_util_caches_cache:hits",
+    "synapse_util_caches_cache_size": "synapse_util_caches_cache:size",
+    "synapse_util_caches_cache_evicted_size": "synapse_util_caches_cache:evicted_size",
+    "synapse_util_caches_cache_total": "synapse_util_caches_cache:total",
+    "synapse_util_caches_response_cache_size": "synapse_util_caches_response_cache:size",
+    "synapse_util_caches_response_cache_hits": "synapse_util_caches_response_cache:hits",
+    "synapse_util_caches_response_cache_evicted_size": "synapse_util_caches_response_cache:evicted_size",
+    "synapse_util_caches_response_cache_total": "synapse_util_caches_response_cache:total",
+}
+
+
 def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> bytes:
+    """
+    Generate metrics in legacy format. Modern metrics are generated directly
+    by prometheus-client.
+    """
 
     # Trigger the cache metrics to be rescraped, which updates the common
     # metrics but do not produce metrics themselves
@@ -94,7 +114,8 @@ def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> byt
             # No samples, don't bother.
             continue
 
-        mname = metric.name
+        # Translate to legacy metric name if it has one.
+        mname = LEGACY_METRIC_NAMES.get(metric.name, metric.name)
         mnewname = metric.name
         mtype = metric.type
 
@@ -124,7 +145,7 @@ def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> byt
         om_samples: Dict[str, List[str]] = {}
         for s in metric.samples:
             for suffix in ["_created", "_gsum", "_gcount"]:
-                if s.name == metric.name + suffix:
+                if s.name == mname + suffix:
                     # OpenMetrics specific sample, put in a gauge at the end.
                     # (these come from gaugehistograms which don't get renamed,
                     # so no need to faff with mnewname)
@@ -140,12 +161,12 @@ def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> byt
             if emit_help:
                 output.append(
                     "# HELP {}{} {}\n".format(
-                        metric.name,
+                        mname,
                         suffix,
                         metric.documentation.replace("\\", r"\\").replace("\n", r"\n"),
                     )
                 )
-            output.append(f"# TYPE {metric.name}{suffix} gauge\n")
+            output.append(f"# TYPE {mname}{suffix} gauge\n")
             output.extend(lines)
 
         # Get rid of the weird colon things while we're at it
@@ -170,11 +191,12 @@ def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> byt
             # Get rid of the OpenMetrics specific samples (we should already have
             # dealt with them above anyway.)
             for suffix in ["_created", "_gsum", "_gcount"]:
-                if s.name == metric.name + suffix:
+                if s.name == mname + suffix:
                     break
             else:
+                sample_name = LEGACY_METRIC_NAMES.get(s.name, s.name)
                 output.append(
-                    sample_line(s, s.name.replace(":total", "").replace(":", "_"))
+                    sample_line(s, sample_name.replace(":total", "").replace(":", "_"))
                 )
 
     return "".join(output).encode("utf-8")
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index eef3462e10..7a1516d3a8 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -235,7 +235,7 @@ def run_as_background_process(
                         f"bgproc.{desc}", tags={SynapseTags.REQUEST_ID: str(context)}
                     )
                 else:
-                    ctx = nullcontext()
+                    ctx = nullcontext()  # type: ignore[assignment]
                 with ctx:
                     return await func(*args, **kwargs)
             except Exception:
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 30b2aeffdd..87ba154cb7 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -35,6 +35,7 @@ from typing_extensions import ParamSpec
 from twisted.internet import defer
 from twisted.web.resource import Resource
 
+from synapse.api import errors
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase
 from synapse.events.presence_router import (
@@ -115,6 +116,7 @@ from synapse.types import (
     JsonDict,
     JsonMapping,
     Requester,
+    RoomAlias,
     StateMap,
     UserID,
     UserInfo,
@@ -163,6 +165,7 @@ __all__ = [
     "EventBase",
     "StateMap",
     "ProfileInfo",
+    "RoomAlias",
     "UserProfile",
 ]
 
@@ -926,10 +929,12 @@ class ModuleApi:
         room_id: str,
         new_membership: str,
         content: Optional[JsonDict] = None,
+        remote_room_hosts: Optional[List[str]] = None,
     ) -> EventBase:
         """Updates the membership of a user to the given value.
 
         Added in Synapse v1.46.0.
+        Changed in Synapse v1.65.0: Added the 'remote_room_hosts' parameter.
 
         Args:
             sender: The user performing the membership change. Must be a user local to
@@ -943,6 +948,7 @@ class ModuleApi:
                 https://spec.matrix.org/unstable/client-server-api/#mroommember for the
                 list of allowed values.
             content: Additional values to include in the resulting event's content.
+            remote_room_hosts: Remote servers to use for remote joins/knocks/etc.
 
         Returns:
             The newly created membership event.
@@ -1002,15 +1008,12 @@ class ModuleApi:
             room_id=room_id,
             action=new_membership,
             content=content,
+            remote_room_hosts=remote_room_hosts,
         )
 
         # Try to retrieve the resulting event.
         event = await self._hs.get_datastores().main.get_event(event_id)
 
-        # update_membership is supposed to always return after the event has been
-        # successfully persisted.
-        assert event is not None
-
         return event
 
     async def create_and_send_event_into_room(self, event_dict: JsonDict) -> EventBase:
@@ -1449,6 +1452,81 @@ class ModuleApi:
             start_timestamp, end_timestamp
         )
 
+    async def lookup_room_alias(self, room_alias: str) -> Tuple[str, List[str]]:
+        """
+        Get the room ID associated with a room alias.
+
+        Added in Synapse v1.65.0.
+
+        Args:
+            room_alias: The alias to look up.
+
+        Returns:
+            A tuple of:
+                The room ID (str).
+                Hosts likely to be participating in the room ([str]).
+
+        Raises:
+            SynapseError if room alias is invalid or could not be found.
+        """
+        alias = RoomAlias.from_string(room_alias)
+        (room_id, hosts) = await self._hs.get_room_member_handler().lookup_room_alias(
+            alias
+        )
+
+        return room_id.to_string(), hosts
+
+    async def create_room(
+        self,
+        user_id: str,
+        config: JsonDict,
+        ratelimit: bool = True,
+        creator_join_profile: Optional[JsonDict] = None,
+    ) -> Tuple[str, Optional[str]]:
+        """Creates a new room.
+
+        Added in Synapse v1.65.0.
+
+        Args:
+            user_id:
+                The user who requested the room creation.
+            config : A dict of configuration options. See "Request body" of:
+                https://spec.matrix.org/latest/client-server-api/#post_matrixclientv3createroom
+            ratelimit: set to False to disable the rate limiter for this specific operation.
+
+            creator_join_profile:
+                Set to override the displayname and avatar for the creating
+                user in this room. If unset, displayname and avatar will be
+                derived from the user's profile. If set, should contain the
+                values to go in the body of the 'join' event (typically
+                `avatar_url` and/or `displayname`.
+
+        Returns:
+                A tuple containing: 1) the room ID (str), 2) if an alias was requested,
+                the room alias (str), otherwise None if no alias was requested.
+
+        Raises:
+            ResourceLimitError if server is blocked to some resource being
+            exceeded.
+            RuntimeError if the user_id does not refer to a local user.
+            SynapseError if the user_id is invalid, room ID couldn't be stored, or
+            something went horribly wrong.
+        """
+        if not self.is_mine(user_id):
+            raise RuntimeError(
+                "Tried to create a room as a user that isn't local to this homeserver",
+            )
+
+        requester = create_requester(user_id)
+        room_id_and_alias, _ = await self._hs.get_room_creation_handler().create_room(
+            requester=requester,
+            config=config,
+            ratelimit=ratelimit,
+            creator_join_profile=creator_join_profile,
+        )
+
+        return room_id_and_alias["room_id"], room_id_and_alias.get("room_alias", None)
+
 
 class PublicRoomListManager:
     """Contains methods for adding to, removing from and querying whether a room
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 54b0ec4b97..c42bb8266a 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -228,6 +228,7 @@ class Notifier:
 
         # Called when there are new things to stream over replication
         self.replication_callbacks: List[Callable[[], None]] = []
+        self._new_join_in_room_callbacks: List[Callable[[str, str], None]] = []
 
         self._federation_client = hs.get_federation_http_client()
 
@@ -280,6 +281,19 @@ class Notifier:
         """
         self.replication_callbacks.append(cb)
 
+    def add_new_join_in_room_callback(self, cb: Callable[[str, str], None]) -> None:
+        """Add a callback that will be called when a user joins a room.
+
+        This only fires on genuine membership changes, e.g. "invite" -> "join".
+        Membership transitions like "join" -> "join" (for e.g. displayname changes) do
+        not trigger the callback.
+
+        When called, the callback receives two arguments: the event ID and the room ID.
+        It should *not* return a Deferred - if it needs to do any asynchronous work, a
+        background thread should be started and wrapped with run_as_background_process.
+        """
+        self._new_join_in_room_callbacks.append(cb)
+
     async def on_new_room_event(
         self,
         event: EventBase,
@@ -723,6 +737,10 @@ class Notifier:
         for cb in self.replication_callbacks:
             cb()
 
+    def notify_user_joined_room(self, event_id: str, room_id: str) -> None:
+        for cb in self._new_join_in_room_callbacks:
+            cb(event_id, room_id)
+
     def notify_remote_server_up(self, server: str) -> None:
         """Notify any replication that a remote server has come back up"""
         # We call federation_sender directly rather than registering as a
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index 819bc9e9b6..440205e80c 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -14,128 +14,235 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import copy
-from typing import Any, Dict, List
-
-from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
+"""
+Push rules is the system used to determine which events trigger a push (and a
+bump in notification counts).
+
+This consists of a list of "push rules" for each user, where a push rule is a
+pair of "conditions" and "actions". When a user receives an event Synapse
+iterates over the list of push rules until it finds one where all the conditions
+match the event, at which point "actions" describe the outcome (e.g. notify,
+highlight, etc).
+
+Push rules are split up into 5 different "kinds" (aka "priority classes"), which
+are run in order:
+    1. Override — highest priority rules, e.g. always ignore notices
+    2. Content — content specific rules, e.g. @ notifications
+    3. Room — per room rules, e.g. enable/disable notifications for all messages
+       in a room
+    4. Sender — per sender rules, e.g. never notify for messages from a given
+       user
+    5. Underride — the lowest priority "default" rules, e.g. notify for every
+       message.
+
+The set of "base rules" are the list of rules that every user has by default. A
+user can modify their copy of the push rules in one of three ways:
+
+    1. Adding a new push rule of a certain kind
+    2. Changing the actions of a base rule
+    3. Enabling/disabling a base rule.
+
+The base rules are split into whether they come before or after a particular
+kind, so the order of push rule evaluation would be: base rules for before
+"override" kind, user defined "override" rules, base rules after "override"
+kind, etc, etc.
+"""
+
+import itertools
+import logging
+from typing import Dict, Iterator, List, Mapping, Sequence, Tuple, Union
+
+import attr
+
+from synapse.config.experimental import ExperimentalConfig
+from synapse.push.rulekinds import PRIORITY_CLASS_MAP
+
+logger = logging.getLogger(__name__)
+
+
+@attr.s(auto_attribs=True, slots=True, frozen=True)
+class PushRule:
+    """A push rule
+
+    Attributes:
+        rule_id: a unique ID for this rule
+        priority_class: what "kind" of push rule this is (see
+            `PRIORITY_CLASS_MAP` for mapping between int and kind)
+        conditions: the sequence of conditions that all need to match
+        actions: the actions to apply if all conditions are met
+        default: is this a base rule?
+        default_enabled: is this enabled by default?
+    """
 
+    rule_id: str
+    priority_class: int
+    conditions: Sequence[Mapping[str, str]]
+    actions: Sequence[Union[str, Mapping]]
+    default: bool = False
+    default_enabled: bool = True
 
-def list_with_base_rules(rawrules: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
-    """Combine the list of rules set by the user with the default push rules
 
-    Args:
-        rawrules: The rules the user has modified or set.
+@attr.s(auto_attribs=True, slots=True, frozen=True, weakref_slot=False)
+class PushRules:
+    """A collection of push rules for an account.
 
-    Returns:
-        A new list with the rules set by the user combined with the defaults.
+    Can be iterated over, producing push rules in priority order.
     """
-    ruleslist = []
 
-    # Grab the base rules that the user has modified.
-    # The modified base rules have a priority_class of -1.
-    modified_base_rules = {r["rule_id"]: r for r in rawrules if r["priority_class"] < 0}
+    # A mapping from rule ID to push rule that overrides a base rule. These will
+    # be returned instead of the base rule.
+    overriden_base_rules: Dict[str, PushRule] = attr.Factory(dict)
+
+    # The following stores the custom push rules at each priority class.
+    #
+    # We keep these separate (rather than combining into one big list) to avoid
+    # copying the base rules around all the time.
+    override: List[PushRule] = attr.Factory(list)
+    content: List[PushRule] = attr.Factory(list)
+    room: List[PushRule] = attr.Factory(list)
+    sender: List[PushRule] = attr.Factory(list)
+    underride: List[PushRule] = attr.Factory(list)
+
+    def __iter__(self) -> Iterator[PushRule]:
+        # When iterating over the push rules we need to return the base rules
+        # interspersed at the correct spots.
+        for rule in itertools.chain(
+            BASE_PREPEND_OVERRIDE_RULES,
+            self.override,
+            BASE_APPEND_OVERRIDE_RULES,
+            self.content,
+            BASE_APPEND_CONTENT_RULES,
+            self.room,
+            self.sender,
+            self.underride,
+            BASE_APPEND_UNDERRIDE_RULES,
+        ):
+            # Check if a base rule has been overriden by a custom rule. If so
+            # return that instead.
+            override_rule = self.overriden_base_rules.get(rule.rule_id)
+            if override_rule:
+                yield override_rule
+            else:
+                yield rule
+
+    def __len__(self) -> int:
+        # The length is mostly used by caches to get a sense of "size" / amount
+        # of memory this object is using, so we only count the number of custom
+        # rules.
+        return (
+            len(self.overriden_base_rules)
+            + len(self.override)
+            + len(self.content)
+            + len(self.room)
+            + len(self.sender)
+            + len(self.underride)
+        )
 
-    # Remove the modified base rules from the list, They'll be added back
-    # in the default positions in the list.
-    rawrules = [r for r in rawrules if r["priority_class"] >= 0]
 
-    # shove the server default rules for each kind onto the end of each
-    current_prio_class = list(PRIORITY_CLASS_INVERSE_MAP)[-1]
+@attr.s(auto_attribs=True, slots=True, frozen=True, weakref_slot=False)
+class FilteredPushRules:
+    """A wrapper around `PushRules` that filters out disabled experimental push
+    rules, and includes the "enabled" state for each rule when iterated over.
+    """
 
-    ruleslist.extend(
-        make_base_prepend_rules(
-            PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
-        )
-    )
+    push_rules: PushRules
+    enabled_map: Dict[str, bool]
+    experimental_config: ExperimentalConfig
 
-    for r in rawrules:
-        if r["priority_class"] < current_prio_class:
-            while r["priority_class"] < current_prio_class:
-                ruleslist.extend(
-                    make_base_append_rules(
-                        PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
-                        modified_base_rules,
-                    )
-                )
-                current_prio_class -= 1
-                if current_prio_class > 0:
-                    ruleslist.extend(
-                        make_base_prepend_rules(
-                            PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
-                            modified_base_rules,
-                        )
-                    )
-
-        ruleslist.append(r)
-
-    while current_prio_class > 0:
-        ruleslist.extend(
-            make_base_append_rules(
-                PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
-            )
-        )
-        current_prio_class -= 1
-        if current_prio_class > 0:
-            ruleslist.extend(
-                make_base_prepend_rules(
-                    PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
-                )
-            )
+    def __iter__(self) -> Iterator[Tuple[PushRule, bool]]:
+        for rule in self.push_rules:
+            if not _is_experimental_rule_enabled(
+                rule.rule_id, self.experimental_config
+            ):
+                continue
 
-    return ruleslist
+            enabled = self.enabled_map.get(rule.rule_id, rule.default_enabled)
 
+            yield rule, enabled
 
-def make_base_append_rules(
-    kind: str, modified_base_rules: Dict[str, Dict[str, Any]]
-) -> List[Dict[str, Any]]:
-    rules = []
+    def __len__(self) -> int:
+        return len(self.push_rules)
 
-    if kind == "override":
-        rules = BASE_APPEND_OVERRIDE_RULES
-    elif kind == "underride":
-        rules = BASE_APPEND_UNDERRIDE_RULES
-    elif kind == "content":
-        rules = BASE_APPEND_CONTENT_RULES
 
-    # Copy the rules before modifying them
-    rules = copy.deepcopy(rules)
-    for r in rules:
-        # Only modify the actions, keep the conditions the same.
-        assert isinstance(r["rule_id"], str)
-        modified = modified_base_rules.get(r["rule_id"])
-        if modified:
-            r["actions"] = modified["actions"]
+DEFAULT_EMPTY_PUSH_RULES = PushRules()
 
-    return rules
 
+def compile_push_rules(rawrules: List[PushRule]) -> PushRules:
+    """Given a set of custom push rules return a `PushRules` instance (which
+    includes the base rules).
+    """
+
+    if not rawrules:
+        # Fast path to avoid allocating empty lists when there are no custom
+        # rules for the user.
+        return DEFAULT_EMPTY_PUSH_RULES
+
+    rules = PushRules()
 
-def make_base_prepend_rules(
-    kind: str,
-    modified_base_rules: Dict[str, Dict[str, Any]],
-) -> List[Dict[str, Any]]:
-    rules = []
+    for rule in rawrules:
+        # We need to decide which bucket each custom push rule goes into.
 
-    if kind == "override":
-        rules = BASE_PREPEND_OVERRIDE_RULES
+        # If it has the same ID as a base rule then it overrides that...
+        overriden_base_rule = BASE_RULES_BY_ID.get(rule.rule_id)
+        if overriden_base_rule:
+            rules.overriden_base_rules[rule.rule_id] = attr.evolve(
+                overriden_base_rule, actions=rule.actions
+            )
+            continue
+
+        # ... otherwise it gets added to the appropriate priority class bucket
+        collection: List[PushRule]
+        if rule.priority_class == 5:
+            collection = rules.override
+        elif rule.priority_class == 4:
+            collection = rules.content
+        elif rule.priority_class == 3:
+            collection = rules.room
+        elif rule.priority_class == 2:
+            collection = rules.sender
+        elif rule.priority_class == 1:
+            collection = rules.underride
+        elif rule.priority_class <= 0:
+            logger.info(
+                "Got rule with priority class less than zero, but doesn't override a base rule: %s",
+                rule,
+            )
+            continue
+        else:
+            # We log and continue here so as not to break event sending
+            logger.error("Unknown priority class: %", rule.priority_class)
+            continue
 
-    # Copy the rules before modifying them
-    rules = copy.deepcopy(rules)
-    for r in rules:
-        # Only modify the actions, keep the conditions the same.
-        assert isinstance(r["rule_id"], str)
-        modified = modified_base_rules.get(r["rule_id"])
-        if modified:
-            r["actions"] = modified["actions"]
+        collection.append(rule)
 
     return rules
 
 
-# We have to annotate these types, otherwise mypy infers them as
-# `List[Dict[str, Sequence[Collection[str]]]]`.
-BASE_APPEND_CONTENT_RULES: List[Dict[str, Any]] = [
-    {
-        "rule_id": "global/content/.m.rule.contains_user_name",
-        "conditions": [
+def _is_experimental_rule_enabled(
+    rule_id: str, experimental_config: ExperimentalConfig
+) -> bool:
+    """Used by `FilteredPushRules` to filter out experimental rules when they
+    have not been enabled.
+    """
+    if (
+        rule_id == "global/override/.org.matrix.msc3786.rule.room.server_acl"
+        and not experimental_config.msc3786_enabled
+    ):
+        return False
+    if (
+        rule_id == "global/underride/.org.matrix.msc3772.thread_reply"
+        and not experimental_config.msc3772_enabled
+    ):
+        return False
+    return True
+
+
+BASE_APPEND_CONTENT_RULES = [
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["content"],
+        rule_id="global/content/.m.rule.contains_user_name",
+        conditions=[
             {
                 "kind": "event_match",
                 "key": "content.body",
@@ -143,29 +250,33 @@ BASE_APPEND_CONTENT_RULES: List[Dict[str, Any]] = [
                 "pattern_type": "user_localpart",
             }
         ],
-        "actions": [
+        actions=[
             "notify",
             {"set_tweak": "sound", "value": "default"},
             {"set_tweak": "highlight"},
         ],
-    }
+    )
 ]
 
 
-BASE_PREPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
-    {
-        "rule_id": "global/override/.m.rule.master",
-        "enabled": False,
-        "conditions": [],
-        "actions": ["dont_notify"],
-    }
+BASE_PREPEND_OVERRIDE_RULES = [
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["override"],
+        rule_id="global/override/.m.rule.master",
+        default_enabled=False,
+        conditions=[],
+        actions=["dont_notify"],
+    )
 ]
 
 
-BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
-    {
-        "rule_id": "global/override/.m.rule.suppress_notices",
-        "conditions": [
+BASE_APPEND_OVERRIDE_RULES = [
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["override"],
+        rule_id="global/override/.m.rule.suppress_notices",
+        conditions=[
             {
                 "kind": "event_match",
                 "key": "content.msgtype",
@@ -173,13 +284,15 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
                 "_cache_key": "_suppress_notices",
             }
         ],
-        "actions": ["dont_notify"],
-    },
+        actions=["dont_notify"],
+    ),
     # NB. .m.rule.invite_for_me must be higher prio than .m.rule.member_event
     # otherwise invites will be matched by .m.rule.member_event
-    {
-        "rule_id": "global/override/.m.rule.invite_for_me",
-        "conditions": [
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["override"],
+        rule_id="global/override/.m.rule.invite_for_me",
+        conditions=[
             {
                 "kind": "event_match",
                 "key": "type",
@@ -195,21 +308,23 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
             # Match the requester's MXID.
             {"kind": "event_match", "key": "state_key", "pattern_type": "user_id"},
         ],
-        "actions": [
+        actions=[
             "notify",
             {"set_tweak": "sound", "value": "default"},
             {"set_tweak": "highlight", "value": False},
         ],
-    },
+    ),
     # Will we sometimes want to know about people joining and leaving?
     # Perhaps: if so, this could be expanded upon. Seems the most usual case
     # is that we don't though. We add this override rule so that even if
     # the room rule is set to notify, we don't get notifications about
     # join/leave/avatar/displayname events.
     # See also: https://matrix.org/jira/browse/SYN-607
-    {
-        "rule_id": "global/override/.m.rule.member_event",
-        "conditions": [
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["override"],
+        rule_id="global/override/.m.rule.member_event",
+        conditions=[
             {
                 "kind": "event_match",
                 "key": "type",
@@ -217,24 +332,28 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
                 "_cache_key": "_member",
             }
         ],
-        "actions": ["dont_notify"],
-    },
+        actions=["dont_notify"],
+    ),
     # This was changed from underride to override so it's closer in priority
     # to the content rules where the user name highlight rule lives. This
     # way a room rule is lower priority than both but a custom override rule
     # is higher priority than both.
-    {
-        "rule_id": "global/override/.m.rule.contains_display_name",
-        "conditions": [{"kind": "contains_display_name"}],
-        "actions": [
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["override"],
+        rule_id="global/override/.m.rule.contains_display_name",
+        conditions=[{"kind": "contains_display_name"}],
+        actions=[
             "notify",
             {"set_tweak": "sound", "value": "default"},
             {"set_tweak": "highlight"},
         ],
-    },
-    {
-        "rule_id": "global/override/.m.rule.roomnotif",
-        "conditions": [
+    ),
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["override"],
+        rule_id="global/override/.m.rule.roomnotif",
+        conditions=[
             {
                 "kind": "event_match",
                 "key": "content.body",
@@ -247,11 +366,13 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
                 "_cache_key": "_roomnotif_pl",
             },
         ],
-        "actions": ["notify", {"set_tweak": "highlight", "value": True}],
-    },
-    {
-        "rule_id": "global/override/.m.rule.tombstone",
-        "conditions": [
+        actions=["notify", {"set_tweak": "highlight", "value": True}],
+    ),
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["override"],
+        rule_id="global/override/.m.rule.tombstone",
+        conditions=[
             {
                 "kind": "event_match",
                 "key": "type",
@@ -265,11 +386,13 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
                 "_cache_key": "_tombstone_statekey",
             },
         ],
-        "actions": ["notify", {"set_tweak": "highlight", "value": True}],
-    },
-    {
-        "rule_id": "global/override/.m.rule.reaction",
-        "conditions": [
+        actions=["notify", {"set_tweak": "highlight", "value": True}],
+    ),
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["override"],
+        rule_id="global/override/.m.rule.reaction",
+        conditions=[
             {
                 "kind": "event_match",
                 "key": "type",
@@ -277,30 +400,40 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
                 "_cache_key": "_reaction",
             }
         ],
-        "actions": ["dont_notify"],
-    },
+        actions=["dont_notify"],
+    ),
     # XXX: This is an experimental rule that is only enabled if msc3786_enabled
     # is enabled, if it is not the rule gets filtered out in _load_rules() in
     # PushRulesWorkerStore
-    {
-        "rule_id": "global/override/.org.matrix.msc3786.rule.room.server_acl",
-        "conditions": [
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["override"],
+        rule_id="global/override/.org.matrix.msc3786.rule.room.server_acl",
+        conditions=[
             {
                 "kind": "event_match",
                 "key": "type",
                 "pattern": "m.room.server_acl",
                 "_cache_key": "_room_server_acl",
-            }
+            },
+            {
+                "kind": "event_match",
+                "key": "state_key",
+                "pattern": "",
+                "_cache_key": "_room_server_acl_state_key",
+            },
         ],
-        "actions": [],
-    },
+        actions=[],
+    ),
 ]
 
 
-BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
-    {
-        "rule_id": "global/underride/.m.rule.call",
-        "conditions": [
+BASE_APPEND_UNDERRIDE_RULES = [
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["underride"],
+        rule_id="global/underride/.m.rule.call",
+        conditions=[
             {
                 "kind": "event_match",
                 "key": "type",
@@ -308,17 +441,19 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
                 "_cache_key": "_call",
             }
         ],
-        "actions": [
+        actions=[
             "notify",
             {"set_tweak": "sound", "value": "ring"},
             {"set_tweak": "highlight", "value": False},
         ],
-    },
+    ),
     # XXX: once m.direct is standardised everywhere, we should use it to detect
     # a DM from the user's perspective rather than this heuristic.
-    {
-        "rule_id": "global/underride/.m.rule.room_one_to_one",
-        "conditions": [
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["underride"],
+        rule_id="global/underride/.m.rule.room_one_to_one",
+        conditions=[
             {"kind": "room_member_count", "is": "2", "_cache_key": "member_count"},
             {
                 "kind": "event_match",
@@ -327,17 +462,19 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
                 "_cache_key": "_message",
             },
         ],
-        "actions": [
+        actions=[
             "notify",
             {"set_tweak": "sound", "value": "default"},
             {"set_tweak": "highlight", "value": False},
         ],
-    },
+    ),
     # XXX: this is going to fire for events which aren't m.room.messages
     # but are encrypted (e.g. m.call.*)...
-    {
-        "rule_id": "global/underride/.m.rule.encrypted_room_one_to_one",
-        "conditions": [
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["underride"],
+        rule_id="global/underride/.m.rule.encrypted_room_one_to_one",
+        conditions=[
             {"kind": "room_member_count", "is": "2", "_cache_key": "member_count"},
             {
                 "kind": "event_match",
@@ -346,15 +483,17 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
                 "_cache_key": "_encrypted",
             },
         ],
-        "actions": [
+        actions=[
             "notify",
             {"set_tweak": "sound", "value": "default"},
             {"set_tweak": "highlight", "value": False},
         ],
-    },
-    {
-        "rule_id": "global/underride/.org.matrix.msc3772.thread_reply",
-        "conditions": [
+    ),
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["underride"],
+        rule_id="global/underride/.org.matrix.msc3772.thread_reply",
+        conditions=[
             {
                 "kind": "org.matrix.msc3772.relation_match",
                 "rel_type": "m.thread",
@@ -362,11 +501,13 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
                 "sender_type": "user_id",
             }
         ],
-        "actions": ["notify", {"set_tweak": "highlight", "value": False}],
-    },
-    {
-        "rule_id": "global/underride/.m.rule.message",
-        "conditions": [
+        actions=["notify", {"set_tweak": "highlight", "value": False}],
+    ),
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["underride"],
+        rule_id="global/underride/.m.rule.message",
+        conditions=[
             {
                 "kind": "event_match",
                 "key": "type",
@@ -374,13 +515,15 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
                 "_cache_key": "_message",
             }
         ],
-        "actions": ["notify", {"set_tweak": "highlight", "value": False}],
-    },
+        actions=["notify", {"set_tweak": "highlight", "value": False}],
+    ),
     # XXX: this is going to fire for events which aren't m.room.messages
     # but are encrypted (e.g. m.call.*)...
-    {
-        "rule_id": "global/underride/.m.rule.encrypted",
-        "conditions": [
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["underride"],
+        rule_id="global/underride/.m.rule.encrypted",
+        conditions=[
             {
                 "kind": "event_match",
                 "key": "type",
@@ -388,11 +531,13 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
                 "_cache_key": "_encrypted",
             }
         ],
-        "actions": ["notify", {"set_tweak": "highlight", "value": False}],
-    },
-    {
-        "rule_id": "global/underride/.im.vector.jitsi",
-        "conditions": [
+        actions=["notify", {"set_tweak": "highlight", "value": False}],
+    ),
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["underride"],
+        rule_id="global/underride/.im.vector.jitsi",
+        conditions=[
             {
                 "kind": "event_match",
                 "key": "type",
@@ -412,29 +557,27 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
                 "_cache_key": "_is_state_event",
             },
         ],
-        "actions": ["notify", {"set_tweak": "highlight", "value": False}],
-    },
+        actions=["notify", {"set_tweak": "highlight", "value": False}],
+    ),
 ]
 
 
 BASE_RULE_IDS = set()
 
+BASE_RULES_BY_ID: Dict[str, PushRule] = {}
+
 for r in BASE_APPEND_CONTENT_RULES:
-    r["priority_class"] = PRIORITY_CLASS_MAP["content"]
-    r["default"] = True
-    BASE_RULE_IDS.add(r["rule_id"])
+    BASE_RULE_IDS.add(r.rule_id)
+    BASE_RULES_BY_ID[r.rule_id] = r
 
 for r in BASE_PREPEND_OVERRIDE_RULES:
-    r["priority_class"] = PRIORITY_CLASS_MAP["override"]
-    r["default"] = True
-    BASE_RULE_IDS.add(r["rule_id"])
+    BASE_RULE_IDS.add(r.rule_id)
+    BASE_RULES_BY_ID[r.rule_id] = r
 
 for r in BASE_APPEND_OVERRIDE_RULES:
-    r["priority_class"] = PRIORITY_CLASS_MAP["override"]
-    r["default"] = True
-    BASE_RULE_IDS.add(r["rule_id"])
+    BASE_RULE_IDS.add(r.rule_id)
+    BASE_RULES_BY_ID[r.rule_id] = r
 
 for r in BASE_APPEND_UNDERRIDE_RULES:
-    r["priority_class"] = PRIORITY_CLASS_MAP["underride"]
-    r["default"] = True
-    BASE_RULE_IDS.add(r["rule_id"])
+    BASE_RULE_IDS.add(r.rule_id)
+    BASE_RULES_BY_ID[r.rule_id] = r
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 7791b289e2..d1caf8a0f7 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -15,9 +15,19 @@
 
 import itertools
 import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+)
 
-import attr
 from prometheus_client import Counter
 
 from synapse.api.constants import EventTypes, Membership, RelationTypes
@@ -26,13 +36,12 @@ from synapse.events import EventBase, relation_from_event
 from synapse.events.snapshot import EventContext
 from synapse.state import POWER_KEY
 from synapse.storage.databases.main.roommember import EventIdMembership
-from synapse.util.async_helpers import Linearizer
-from synapse.util.caches import CacheMetric, register_cache
-from synapse.util.caches.descriptors import lru_cache
-from synapse.util.caches.lrucache import LruCache
+from synapse.storage.state import StateFilter
+from synapse.util.caches import register_cache
 from synapse.util.metrics import measure_func
+from synapse.visibility import filter_event_for_clients_with_state
 
-from ..storage.state import StateFilter
+from .baserules import FilteredPushRules, PushRule
 from .push_rule_evaluator import PushRuleEvaluatorForEvent
 
 if TYPE_CHECKING:
@@ -48,15 +57,6 @@ push_rules_state_size_counter = Counter(
     "synapse_push_bulk_push_rule_evaluator_push_rules_state_size_counter", ""
 )
 
-# Measures whether we use the fast path of using state deltas, or if we have to
-# recalculate from scratch
-push_rules_delta_state_cache_metric = register_cache(
-    "cache",
-    "push_rules_delta_state_cache_metric",
-    cache=[],  # Meaningless size, as this isn't a cache that stores values
-    resizable=False,
-)
-
 
 STATE_EVENT_TYPES_TO_MARK_UNREAD = {
     EventTypes.Topic,
@@ -111,10 +111,6 @@ class BulkPushRuleEvaluator:
         self.clock = hs.get_clock()
         self._event_auth_handler = hs.get_event_auth_handler()
 
-        # Used by `RulesForRoom` to ensure only one thing mutates the cache at a
-        # time. Keyed off room_id.
-        self._rules_linearizer = Linearizer(name="rules_for_room")
-
         self.room_push_rule_cache_metrics = register_cache(
             "cache",
             "room_push_rule_cache",
@@ -126,47 +122,54 @@ class BulkPushRuleEvaluator:
         self._relations_match_enabled = self.hs.config.experimental.msc3772_enabled
 
     async def _get_rules_for_event(
-        self, event: EventBase, context: EventContext
-    ) -> Dict[str, List[Dict[str, Any]]]:
-        """This gets the rules for all users in the room at the time of the event,
-        as well as the push rules for the invitee if the event is an invite.
+        self,
+        event: EventBase,
+    ) -> Dict[str, FilteredPushRules]:
+        """Get the push rules for all users who may need to be notified about
+        the event.
+
+        Note: this does not check if the user is allowed to see the event.
 
         Returns:
-            dict of user_id -> push_rules
+            Mapping of user ID to their push rules.
         """
-        room_id = event.room_id
-
-        rules_for_room_data = self._get_rules_for_room(room_id)
-        rules_for_room = RulesForRoom(
-            hs=self.hs,
-            room_id=room_id,
-            rules_for_room_cache=self._get_rules_for_room.cache,
-            room_push_rule_cache_metrics=self.room_push_rule_cache_metrics,
-            linearizer=self._rules_linearizer,
-            cached_data=rules_for_room_data,
-        )
-
-        rules_by_user = await rules_for_room.get_rules(event, context)
+        # We get the users who may need to be notified by first fetching the
+        # local users currently in the room, finding those that have push rules,
+        # and *then* checking which users are actually allowed to see the event.
+        #
+        # The alternative is to first fetch all users that were joined at the
+        # event, but that requires fetching the full state at the event, which
+        # may be expensive for large rooms with few local users.
+
+        local_users = await self.store.get_local_users_in_room(event.room_id)
+
+        # Filter out appservice users.
+        local_users = [
+            u
+            for u in local_users
+            if not self.store.get_if_app_services_interested_in_user(u)
+        ]
 
         # if this event is an invite event, we may need to run rules for the user
         # who's been invited, otherwise they won't get told they've been invited
-        if event.type == "m.room.member" and event.content["membership"] == "invite":
+        if event.type == EventTypes.Member and event.membership == Membership.INVITE:
             invited = event.state_key
-            if invited and self.hs.is_mine_id(invited):
-                rules_by_user = dict(rules_by_user)
-                rules_by_user[invited] = await self.store.get_push_rules_for_user(
-                    invited
-                )
+            if invited and self.hs.is_mine_id(invited) and invited not in local_users:
+                local_users = list(local_users)
+                local_users.append(invited)
 
-        return rules_by_user
+        rules_by_user = await self.store.bulk_get_push_rules(local_users)
 
-    @lru_cache()
-    def _get_rules_for_room(self, room_id: str) -> "RulesForRoomData":
-        """Get the current RulesForRoomData object for the given room id"""
-        # It's important that the RulesForRoomData object gets added to self._get_rules_for_room.cache
-        # before any lookup methods get called on it as otherwise there may be
-        # a race if invalidate_all gets called (which assumes its in the cache)
-        return RulesForRoomData()
+        logger.debug("Users in room: %s", local_users)
+
+        if logger.isEnabledFor(logging.DEBUG):
+            logger.debug(
+                "Returning push rules for %r %r",
+                event.room_id,
+                list(rules_by_user.keys()),
+            )
+
+        return rules_by_user
 
     async def _get_power_levels_and_sender_level(
         self, event: EventBase, context: EventContext
@@ -195,7 +198,7 @@ class BulkPushRuleEvaluator:
         return pl_event.content if pl_event else {}, sender_level
 
     async def _get_mutual_relations(
-        self, event: EventBase, rules: Iterable[Dict[str, Any]]
+        self, event: EventBase, rules: Iterable[Tuple[PushRule, bool]]
     ) -> Dict[str, Set[Tuple[str, str]]]:
         """
         Fetch event metadata for events which related to the same event as the given event.
@@ -225,12 +228,11 @@ class BulkPushRuleEvaluator:
 
         # Pre-filter to figure out which relation types are interesting.
         rel_types = set()
-        for rule in rules:
-            # Skip disabled rules.
-            if "enabled" in rule and not rule["enabled"]:
+        for rule, enabled in rules:
+            if not enabled:
                 continue
 
-            for condition in rule["conditions"]:
+            for condition in rule.conditions:
                 if condition["kind"] != "org.matrix.msc3772.relation_match":
                     continue
 
@@ -260,12 +262,19 @@ class BulkPushRuleEvaluator:
             # This can happen due to out of band memberships
             return
 
-        count_as_unread = _should_count_as_unread(event, context)
+        # Disable counting as unread unless the experimental configuration is
+        # enabled, as it can cause additional (unwanted) rows to be added to the
+        # event_push_actions table.
+        count_as_unread = False
+        if self.hs.config.experimental.msc2654_enabled:
+            count_as_unread = _should_count_as_unread(event, context)
 
-        rules_by_user = await self._get_rules_for_event(event, context)
-        actions_by_user: Dict[str, List[Union[dict, str]]] = {}
+        rules_by_user = await self._get_rules_for_event(event)
+        actions_by_user: Dict[str, Collection[Union[Mapping, str]]] = {}
 
-        room_members = await self.store.get_joined_users_from_context(event, context)
+        room_member_count = await self.store.get_number_joined_users_in_room(
+            event.room_id
+        )
 
         (
             power_levels,
@@ -278,30 +287,36 @@ class BulkPushRuleEvaluator:
 
         evaluator = PushRuleEvaluatorForEvent(
             event,
-            len(room_members),
+            room_member_count,
             sender_power_level,
             power_levels,
             relations,
             self._relations_match_enabled,
         )
 
-        # If the event is not a state event check if any users ignore the sender.
-        if not event.is_state():
-            ignorers = await self.store.ignored_by(event.sender)
-        else:
-            ignorers = frozenset()
+        users = rules_by_user.keys()
+        profiles = await self.store.get_subset_users_in_room_with_profiles(
+            event.room_id, users
+        )
+
+        # This is a check for the case where user joins a room without being
+        # allowed to see history, and then the server receives a delayed event
+        # from before the user joined, which they should not be pushed for
+        uids_with_visibility = await filter_event_for_clients_with_state(
+            self.store, users, event, context
+        )
 
         for uid, rules in rules_by_user.items():
             if event.sender == uid:
                 continue
 
-            if uid in ignorers:
+            if uid not in uids_with_visibility:
                 continue
 
             display_name = None
-            profile_info = room_members.get(uid)
-            if profile_info:
-                display_name = profile_info.display_name
+            profile = profiles.get(uid)
+            if profile:
+                display_name = profile.display_name
 
             if not display_name:
                 # Handle the case where we are pushing a membership event to
@@ -318,15 +333,13 @@ class BulkPushRuleEvaluator:
                 # current user, it'll be added to the dict later.
                 actions_by_user[uid] = []
 
-            for rule in rules:
-                if "enabled" in rule and not rule["enabled"]:
+            for rule, enabled in rules:
+                if not enabled:
                     continue
 
-                matches = evaluator.check_conditions(
-                    rule["conditions"], uid, display_name
-                )
+                matches = evaluator.check_conditions(rule.conditions, uid, display_name)
                 if matches:
-                    actions = [x for x in rule["actions"] if x != "dont_notify"]
+                    actions = [x for x in rule.actions if x != "dont_notify"]
                     if actions and "notify" in actions:
                         # Push rules say we should notify the user of this event
                         actions_by_user[uid] = actions
@@ -346,283 +359,3 @@ MemberMap = Dict[str, Optional[EventIdMembership]]
 Rule = Dict[str, dict]
 RulesByUser = Dict[str, List[Rule]]
 StateGroup = Union[object, int]
-
-
-@attr.s(slots=True, auto_attribs=True)
-class RulesForRoomData:
-    """The data stored in the cache by `RulesForRoom`.
-
-    We don't store `RulesForRoom` directly in the cache as we want our caches to
-    *only* include data, and not references to e.g. the data stores.
-    """
-
-    # event_id -> EventIdMembership
-    member_map: MemberMap = attr.Factory(dict)
-    # user_id -> rules
-    rules_by_user: RulesByUser = attr.Factory(dict)
-
-    # The last state group we updated the caches for. If the state_group of
-    # a new event comes along, we know that we can just return the cached
-    # result.
-    # On invalidation of the rules themselves (if the user changes them),
-    # we invalidate everything and set state_group to `object()`
-    state_group: StateGroup = attr.Factory(object)
-
-    # A sequence number to keep track of when we're allowed to update the
-    # cache. We bump the sequence number when we invalidate the cache. If
-    # the sequence number changes while we're calculating stuff we should
-    # not update the cache with it.
-    sequence: int = 0
-
-    # A cache of user_ids that we *know* aren't interesting, e.g. user_ids
-    # owned by AS's, or remote users, etc. (I.e. users we will never need to
-    # calculate push for)
-    # These never need to be invalidated as we will never set up push for
-    # them.
-    uninteresting_user_set: Set[str] = attr.Factory(set)
-
-
-class RulesForRoom:
-    """Caches push rules for users in a room.
-
-    This efficiently handles users joining/leaving the room by not invalidating
-    the entire cache for the room.
-
-    A new instance is constructed for each call to
-    `BulkPushRuleEvaluator._get_rules_for_event`, with the cached data from
-    previous calls passed in.
-    """
-
-    def __init__(
-        self,
-        hs: "HomeServer",
-        room_id: str,
-        rules_for_room_cache: LruCache,
-        room_push_rule_cache_metrics: CacheMetric,
-        linearizer: Linearizer,
-        cached_data: RulesForRoomData,
-    ):
-        """
-        Args:
-            hs: The HomeServer object.
-            room_id: The room ID.
-            rules_for_room_cache: The cache object that caches these
-                RoomsForUser objects.
-            room_push_rule_cache_metrics: The metrics object
-            linearizer: The linearizer used to ensure only one thing mutates
-                the cache at a time. Keyed off room_id
-            cached_data: Cached data from previous calls to `self.get_rules`,
-                can be mutated.
-        """
-        self.room_id = room_id
-        self.is_mine_id = hs.is_mine_id
-        self.store = hs.get_datastores().main
-        self.room_push_rule_cache_metrics = room_push_rule_cache_metrics
-
-        # Used to ensure only one thing mutates the cache at a time. Keyed off
-        # room_id.
-        self.linearizer = linearizer
-
-        self.data = cached_data
-
-        # We need to be clever on the invalidating caches callbacks, as
-        # otherwise the invalidation callback holds a reference to the object,
-        # potentially causing it to leak.
-        # To get around this we pass a function that on invalidations looks ups
-        # the RoomsForUser entry in the cache, rather than keeping a reference
-        # to self around in the callback.
-        self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)
-
-    async def get_rules(
-        self, event: EventBase, context: EventContext
-    ) -> Dict[str, List[Dict[str, dict]]]:
-        """Given an event context return the rules for all users who are
-        currently in the room.
-        """
-        state_group = context.state_group
-
-        if state_group and self.data.state_group == state_group:
-            logger.debug("Using cached rules for %r", self.room_id)
-            self.room_push_rule_cache_metrics.inc_hits()
-            return self.data.rules_by_user
-
-        async with self.linearizer.queue(self.room_id):
-            if state_group and self.data.state_group == state_group:
-                logger.debug("Using cached rules for %r", self.room_id)
-                self.room_push_rule_cache_metrics.inc_hits()
-                return self.data.rules_by_user
-
-            self.room_push_rule_cache_metrics.inc_misses()
-
-            ret_rules_by_user = {}
-            missing_member_event_ids = {}
-            if state_group and self.data.state_group == context.prev_group:
-                # If we have a simple delta then we can reuse most of the previous
-                # results.
-                ret_rules_by_user = self.data.rules_by_user
-                current_state_ids = context.delta_ids
-
-                push_rules_delta_state_cache_metric.inc_hits()
-            else:
-                current_state_ids = await context.get_current_state_ids()
-                push_rules_delta_state_cache_metric.inc_misses()
-            # Ensure the state IDs exist.
-            assert current_state_ids is not None
-
-            push_rules_state_size_counter.inc(len(current_state_ids))
-
-            logger.debug(
-                "Looking for member changes in %r %r", state_group, current_state_ids
-            )
-
-            # Loop through to see which member events we've seen and have rules
-            # for and which we need to fetch
-            for key in current_state_ids:
-                typ, user_id = key
-                if typ != EventTypes.Member:
-                    continue
-
-                if user_id in self.data.uninteresting_user_set:
-                    continue
-
-                if not self.is_mine_id(user_id):
-                    self.data.uninteresting_user_set.add(user_id)
-                    continue
-
-                if self.store.get_if_app_services_interested_in_user(user_id):
-                    self.data.uninteresting_user_set.add(user_id)
-                    continue
-
-                event_id = current_state_ids[key]
-
-                res = self.data.member_map.get(event_id, None)
-                if res:
-                    if res.membership == Membership.JOIN:
-                        rules = self.data.rules_by_user.get(res.user_id, None)
-                        if rules:
-                            ret_rules_by_user[res.user_id] = rules
-                    continue
-
-                # If a user has left a room we remove their push rule. If they
-                # joined then we re-add it later in _update_rules_with_member_event_ids
-                ret_rules_by_user.pop(user_id, None)
-                missing_member_event_ids[user_id] = event_id
-
-            if missing_member_event_ids:
-                # If we have some member events we haven't seen, look them up
-                # and fetch push rules for them if appropriate.
-                logger.debug("Found new member events %r", missing_member_event_ids)
-                await self._update_rules_with_member_event_ids(
-                    ret_rules_by_user, missing_member_event_ids, state_group, event
-                )
-            else:
-                # The push rules didn't change but lets update the cache anyway
-                self.update_cache(
-                    self.data.sequence,
-                    members={},  # There were no membership changes
-                    rules_by_user=ret_rules_by_user,
-                    state_group=state_group,
-                )
-
-        if logger.isEnabledFor(logging.DEBUG):
-            logger.debug(
-                "Returning push rules for %r %r", self.room_id, ret_rules_by_user.keys()
-            )
-        return ret_rules_by_user
-
-    async def _update_rules_with_member_event_ids(
-        self,
-        ret_rules_by_user: Dict[str, list],
-        member_event_ids: Dict[str, str],
-        state_group: Optional[int],
-        event: EventBase,
-    ) -> None:
-        """Update the partially filled rules_by_user dict by fetching rules for
-        any newly joined users in the `member_event_ids` list.
-
-        Args:
-            ret_rules_by_user: Partially filled dict of push rules. Gets
-                updated with any new rules.
-            member_event_ids: Dict of user id to event id for membership events
-                that have happened since the last time we filled rules_by_user
-            state_group: The state group we are currently computing push rules
-                for. Used when updating the cache.
-            event: The event we are currently computing push rules for.
-        """
-        sequence = self.data.sequence
-
-        members = await self.store.get_membership_from_event_ids(
-            member_event_ids.values()
-        )
-
-        # If the event is a join event then it will be in current state events
-        # map but not in the DB, so we have to explicitly insert it.
-        if event.type == EventTypes.Member:
-            for event_id in member_event_ids.values():
-                if event_id == event.event_id:
-                    members[event_id] = EventIdMembership(
-                        user_id=event.state_key, membership=event.membership
-                    )
-
-        if logger.isEnabledFor(logging.DEBUG):
-            logger.debug("Found members %r: %r", self.room_id, members.values())
-
-        joined_user_ids = {
-            entry.user_id
-            for entry in members.values()
-            if entry and entry.membership == Membership.JOIN
-        }
-
-        logger.debug("Joined: %r", joined_user_ids)
-
-        # Previously we only considered users with pushers or read receipts in that
-        # room. We can't do this anymore because we use push actions to calculate unread
-        # counts, which don't rely on the user having pushers or sent a read receipt into
-        # the room. Therefore we just need to filter for local users here.
-        user_ids = list(filter(self.is_mine_id, joined_user_ids))
-
-        rules_by_user = await self.store.bulk_get_push_rules(
-            user_ids, on_invalidate=self.invalidate_all_cb
-        )
-
-        ret_rules_by_user.update(
-            item for item in rules_by_user.items() if item[0] is not None
-        )
-
-        self.update_cache(sequence, members, ret_rules_by_user, state_group)
-
-    def update_cache(
-        self,
-        sequence: int,
-        members: MemberMap,
-        rules_by_user: RulesByUser,
-        state_group: StateGroup,
-    ) -> None:
-        if sequence == self.data.sequence:
-            self.data.member_map.update(members)
-            self.data.rules_by_user = rules_by_user
-            self.data.state_group = state_group
-
-
-@attr.attrs(slots=True, frozen=True, auto_attribs=True)
-class _Invalidation:
-    # _Invalidation is passed as an `on_invalidate` callback to bulk_get_push_rules,
-    # which means that it it is stored on the bulk_get_push_rules cache entry. In order
-    # to ensure that we don't accumulate lots of redundant callbacks on the cache entry,
-    # we need to ensure that two _Invalidation objects are "equal" if they refer to the
-    # same `cache` and `room_id`.
-    #
-    # attrs provides suitable __hash__ and __eq__ methods, provided we remember to
-    # set `frozen=True`.
-
-    cache: LruCache
-    room_id: str
-
-    def __call__(self) -> None:
-        rules_data = self.cache.get(self.room_id, None, update_metrics=False)
-        if rules_data:
-            rules_data.sequence += 1
-            rules_data.state_group = object()
-            rules_data.member_map = {}
-            rules_data.rules_by_user = {}
-            push_rules_invalidation_counter.inc()
diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py
index 5117ef6854..73618d9234 100644
--- a/synapse/push/clientformat.py
+++ b/synapse/push/clientformat.py
@@ -18,16 +18,15 @@ from typing import Any, Dict, List, Optional
 from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
 from synapse.types import UserID
 
+from .baserules import FilteredPushRules, PushRule
+
 
 def format_push_rules_for_user(
-    user: UserID, ruleslist: List
+    user: UserID, ruleslist: FilteredPushRules
 ) -> Dict[str, Dict[str, list]]:
     """Converts a list of rawrules and a enabled map into nested dictionaries
     to match the Matrix client-server format for push rules"""
 
-    # We're going to be mutating this a lot, so do a deep copy
-    ruleslist = copy.deepcopy(ruleslist)
-
     rules: Dict[str, Dict[str, List[Dict[str, Any]]]] = {
         "global": {},
         "device": {},
@@ -35,11 +34,30 @@ def format_push_rules_for_user(
 
     rules["global"] = _add_empty_priority_class_arrays(rules["global"])
 
-    for r in ruleslist:
-        template_name = _priority_class_to_template_name(r["priority_class"])
+    for r, enabled in ruleslist:
+        template_name = _priority_class_to_template_name(r.priority_class)
+
+        rulearray = rules["global"][template_name]
+
+        template_rule = _rule_to_template(r)
+        if not template_rule:
+            continue
+
+        rulearray.append(template_rule)
+
+        template_rule["enabled"] = enabled
+
+        if "conditions" not in template_rule:
+            # Not all formatted rules have explicit conditions, e.g. "room"
+            # rules omit them as they can be derived from the kind and rule ID.
+            #
+            # If the formatted rule has no conditions then we can skip the
+            # formatting of conditions.
+            continue
 
         # Remove internal stuff.
-        for c in r["conditions"]:
+        template_rule["conditions"] = copy.deepcopy(template_rule["conditions"])
+        for c in template_rule["conditions"]:
             c.pop("_cache_key", None)
 
             pattern_type = c.pop("pattern_type", None)
@@ -52,16 +70,6 @@ def format_push_rules_for_user(
             if sender_type == "user_id":
                 c["sender"] = user.to_string()
 
-        rulearray = rules["global"][template_name]
-
-        template_rule = _rule_to_template(r)
-        if template_rule:
-            if "enabled" in r:
-                template_rule["enabled"] = r["enabled"]
-            else:
-                template_rule["enabled"] = True
-            rulearray.append(template_rule)
-
     return rules
 
 
@@ -71,24 +79,24 @@ def _add_empty_priority_class_arrays(d: Dict[str, list]) -> Dict[str, list]:
     return d
 
 
-def _rule_to_template(rule: Dict[str, Any]) -> Optional[Dict[str, Any]]:
-    unscoped_rule_id = None
-    if "rule_id" in rule:
-        unscoped_rule_id = _rule_id_from_namespaced(rule["rule_id"])
+def _rule_to_template(rule: PushRule) -> Optional[Dict[str, Any]]:
+    templaterule: Dict[str, Any]
+
+    unscoped_rule_id = _rule_id_from_namespaced(rule.rule_id)
 
-    template_name = _priority_class_to_template_name(rule["priority_class"])
+    template_name = _priority_class_to_template_name(rule.priority_class)
     if template_name in ["override", "underride"]:
-        templaterule = {k: rule[k] for k in ["conditions", "actions"]}
+        templaterule = {"conditions": rule.conditions, "actions": rule.actions}
     elif template_name in ["sender", "room"]:
-        templaterule = {"actions": rule["actions"]}
-        unscoped_rule_id = rule["conditions"][0]["pattern"]
+        templaterule = {"actions": rule.actions}
+        unscoped_rule_id = rule.conditions[0]["pattern"]
     elif template_name == "content":
-        if len(rule["conditions"]) != 1:
+        if len(rule.conditions) != 1:
             return None
-        thecond = rule["conditions"][0]
+        thecond = rule.conditions[0]
         if "pattern" not in thecond:
             return None
-        templaterule = {"actions": rule["actions"]}
+        templaterule = {"actions": rule.actions}
         templaterule["pattern"] = thecond["pattern"]
     else:
         # This should not be reached unless this function is not kept in sync
@@ -97,8 +105,8 @@ def _rule_to_template(rule: Dict[str, Any]) -> Optional[Dict[str, Any]]:
 
     if unscoped_rule_id:
         templaterule["rule_id"] = unscoped_rule_id
-    if "default" in rule:
-        templaterule["default"] = rule["default"]
+    if rule.default:
+        templaterule["default"] = True
     return templaterule
 
 
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 015c19b2d9..c2575ba3d9 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -860,13 +860,14 @@ class Mailer:
              A link to unsubscribe from email notifications.
         """
         params = {
-            "access_token": self.macaroon_gen.generate_delete_pusher_token(user_id),
+            "access_token": self.macaroon_gen.generate_delete_pusher_token(
+                user_id, app_id, email_address
+            ),
             "app_id": app_id,
             "pushkey": email_address,
         }
 
-        # XXX: make r0 once API is stable
-        return "%s_matrix/client/unstable/pushers/remove?%s" % (
+        return "%s_synapse/client/unsubscribe?%s" % (
             self.hs.config.server.public_baseurl,
             urllib.parse.urlencode(params),
         )
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 2e8a017add..3c5632cd91 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -15,7 +15,18 @@
 
 import logging
 import re
-from typing import Any, Dict, List, Mapping, Optional, Pattern, Set, Tuple, Union
+from typing import (
+    Any,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Pattern,
+    Sequence,
+    Set,
+    Tuple,
+    Union,
+)
 
 from matrix_common.regex import glob_to_regex, to_word_pattern
 
@@ -32,14 +43,14 @@ INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
 
 
 def _room_member_count(
-    ev: EventBase, condition: Dict[str, Any], room_member_count: int
+    ev: EventBase, condition: Mapping[str, Any], room_member_count: int
 ) -> bool:
     return _test_ineq_condition(condition, room_member_count)
 
 
 def _sender_notification_permission(
     ev: EventBase,
-    condition: Dict[str, Any],
+    condition: Mapping[str, Any],
     sender_power_level: int,
     power_levels: Dict[str, Union[int, Dict[str, int]]],
 ) -> bool:
@@ -54,7 +65,7 @@ def _sender_notification_permission(
     return sender_power_level >= room_notif_level
 
 
-def _test_ineq_condition(condition: Dict[str, Any], number: int) -> bool:
+def _test_ineq_condition(condition: Mapping[str, Any], number: int) -> bool:
     if "is" not in condition:
         return False
     m = INEQUALITY_EXPR.match(condition["is"])
@@ -137,7 +148,7 @@ class PushRuleEvaluatorForEvent:
         self._condition_cache: Dict[str, bool] = {}
 
     def check_conditions(
-        self, conditions: List[dict], uid: str, display_name: Optional[str]
+        self, conditions: Sequence[Mapping], uid: str, display_name: Optional[str]
     ) -> bool:
         """
         Returns true if a user's conditions/user ID/display name match the event.
@@ -169,7 +180,7 @@ class PushRuleEvaluatorForEvent:
         return True
 
     def matches(
-        self, condition: Dict[str, Any], user_id: str, display_name: Optional[str]
+        self, condition: Mapping[str, Any], user_id: str, display_name: Optional[str]
     ) -> bool:
         """
         Returns true if a user's condition/user ID/display name match the event.
@@ -204,7 +215,7 @@ class PushRuleEvaluatorForEvent:
             #     endpoint with an unknown kind, see _rule_tuple_from_request_object.
             return True
 
-    def _event_match(self, condition: dict, user_id: str) -> bool:
+    def _event_match(self, condition: Mapping, user_id: str) -> bool:
         """
         Check an "event_match" push rule condition.
 
@@ -269,7 +280,7 @@ class PushRuleEvaluatorForEvent:
 
         return bool(r.search(body))
 
-    def _relation_match(self, condition: dict, user_id: str) -> bool:
+    def _relation_match(self, condition: Mapping, user_id: str) -> bool:
         """
         Check an "relation_match" push rule condition.
 
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index 8397229ccb..6661887d9f 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 from typing import Dict
 
-from synapse.api.constants import ReceiptTypes
 from synapse.events import EventBase
 from synapse.push.presentable_names import calculate_room_name, name_from_member_event
 from synapse.storage.controllers import StorageControllers
@@ -24,30 +23,24 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
     invites = await store.get_invited_rooms_for_local_user(user_id)
     joins = await store.get_rooms_for_user(user_id)
 
-    my_receipts_by_room = await store.get_receipts_for_user(
-        user_id, (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE)
-    )
-
     badge = len(invites)
 
     for room_id in joins:
-        if room_id in my_receipts_by_room:
-            last_unread_event_id = my_receipts_by_room[room_id]
-
-            notifs = await (
-                store.get_unread_event_push_actions_by_room_for_user(
-                    room_id, user_id, last_unread_event_id
-                )
+        notifs = await (
+            store.get_unread_event_push_actions_by_room_for_user(
+                room_id,
+                user_id,
             )
-            if notifs.notify_count == 0:
-                continue
+        )
+        if notifs.notify_count == 0:
+            continue
 
-            if group_by_room:
-                # return one badge count per conversation
-                badge += 1
-            else:
-                # increment the badge count by the number of unread messages in the room
-                badge += notifs.notify_count
+        if group_by_room:
+            # return one badge count per conversation
+            badge += 1
+        else:
+            # increment the badge count by the number of unread messages in the room
+            badge += notifs.notify_count
     return badge
 
 
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index d0cc657b44..1e0ef44fc7 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -328,7 +328,7 @@ class PusherPool:
             return None
 
         try:
-            p = self.pusher_factory.create_pusher(pusher_config)
+            pusher = self.pusher_factory.create_pusher(pusher_config)
         except PusherConfigException as e:
             logger.warning(
                 "Pusher incorrectly configured id=%i, user=%s, appid=%s, pushkey=%s: %s",
@@ -346,23 +346,28 @@ class PusherPool:
             )
             return None
 
-        if not p:
+        if not pusher:
             return None
 
-        appid_pushkey = "%s:%s" % (pusher_config.app_id, pusher_config.pushkey)
+        appid_pushkey = "%s:%s" % (pusher.app_id, pusher.pushkey)
 
-        byuser = self.pushers.setdefault(pusher_config.user_name, {})
+        byuser = self.pushers.setdefault(pusher.user_id, {})
         if appid_pushkey in byuser:
-            byuser[appid_pushkey].on_stop()
-        byuser[appid_pushkey] = p
+            previous_pusher = byuser[appid_pushkey]
+            previous_pusher.on_stop()
 
-        synapse_pushers.labels(type(p).__name__, p.app_id).inc()
+            synapse_pushers.labels(
+                type(previous_pusher).__name__, previous_pusher.app_id
+            ).dec()
+        byuser[appid_pushkey] = pusher
+
+        synapse_pushers.labels(type(pusher).__name__, pusher.app_id).inc()
 
         # Check if there *may* be push to process. We do this as this check is a
         # lot cheaper to do than actually fetching the exact rows we need to
         # push.
-        user_id = pusher_config.user_name
-        last_stream_ordering = pusher_config.last_stream_ordering
+        user_id = pusher.user_id
+        last_stream_ordering = pusher.last_stream_ordering
         if last_stream_ordering:
             have_notifs = await self.store.get_if_maybe_push_in_range_for_user(
                 user_id, last_stream_ordering
@@ -372,9 +377,9 @@ class PusherPool:
             # risk missing push.
             have_notifs = True
 
-        p.on_started(have_notifs)
+        pusher.on_started(have_notifs)
 
-        return p
+        return pusher
 
     async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
         appid_pushkey = "%s:%s" % (app_id, pushkey)
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index aec040ee19..53aa7fa4c6 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -25,6 +25,7 @@ from synapse.replication.http import (
     push,
     register,
     send_event,
+    state,
     streams,
 )
 
@@ -48,6 +49,7 @@ class ReplicationRestResource(JsonResource):
         streams.register_servlets(hs, self)
         account_data.register_servlets(hs, self)
         push.register_servlets(hs, self)
+        state.register_servlets(hs, self)
 
         # The following can't currently be instantiated on workers.
         if hs.config.worker.worker_app is None:
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index a4ae4040c3..acb0bd18f7 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -26,12 +26,13 @@ from twisted.web.server import Request
 
 from synapse.api.errors import HttpResponseException, SynapseError
 from synapse.http import RequestTimedOutError
-from synapse.http.server import HttpServer, is_method_cancellable
+from synapse.http.server import HttpServer
 from synapse.http.site import SynapseRequest
 from synapse.logging import opentracing
-from synapse.logging.opentracing import trace
+from synapse.logging.opentracing import trace_with_opname
 from synapse.types import JsonDict
 from synapse.util.caches.response_cache import ResponseCache
+from synapse.util.cancellation import is_function_cancellable
 from synapse.util.stringutils import random_string
 
 if TYPE_CHECKING:
@@ -196,7 +197,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
                 "ascii"
             )
 
-        @trace(opname="outgoing_replication_request")
+        @trace_with_opname("outgoing_replication_request")
         async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:
             with outgoing_gauge.track_inprogress():
                 if instance_name == local_instance_name:
@@ -311,7 +312,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         url_args = list(self.PATH_ARGS)
         method = self.METHOD
 
-        if self.CACHE and is_method_cancellable(self._handle_request):
+        if self.CACHE and is_function_cancellable(self._handle_request):
             raise Exception(
                 f"{self.__class__.__name__} has been marked as cancellable, but CACHE "
                 "is set. The cancellable flag would have no effect."
@@ -359,6 +360,6 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         # The `@cancellable` decorator may be applied to `_handle_request`. But we
         # told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`,
         # so we have to set up the cancellable flag ourselves.
-        request.is_render_cancellable = is_method_cancellable(self._handle_request)
+        request.is_render_cancellable = is_function_cancellable(self._handle_request)
 
         return await self._handle_request(request, **kwargs)
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index eed29cd597..d3abafed28 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -60,6 +60,9 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
         {
             "max_stream_id": 32443,
         }
+
+    Responds with a 409 when a `PartialStateConflictError` is raised due to an event
+    context that needs to be recomputed due to the un-partial stating of a room.
     """
 
     NAME = "fed_send_events"
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index c2b2588ea5..486f04723c 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -59,6 +59,9 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
 
         { "stream_id": 12345, "event_id": "$abcdef..." }
 
+    Responds with a 409 when a `PartialStateConflictError` is raised due to an event
+    context that needs to be recomputed due to the un-partial stating of a room.
+
     The returned event ID may not match the sent event if it was deduplicated.
     """
 
diff --git a/synapse/replication/http/state.py b/synapse/replication/http/state.py
new file mode 100644
index 0000000000..838b7584e5
--- /dev/null
+++ b/synapse/replication/http/state.py
@@ -0,0 +1,75 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING, Tuple
+
+from twisted.web.server import Request
+
+from synapse.api.errors import SynapseError
+from synapse.http.server import HttpServer
+from synapse.replication.http._base import ReplicationEndpoint
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationUpdateCurrentStateRestServlet(ReplicationEndpoint):
+    """Recalculates the current state for a room, and persists it.
+
+    The API looks like:
+
+        POST /_synapse/replication/update_current_state/:room_id
+
+        {}
+
+        200 OK
+
+        {}
+    """
+
+    NAME = "update_current_state"
+    PATH_ARGS = ("room_id",)
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
+
+        self._state_handler = hs.get_state_handler()
+        self._events_shard_config = hs.config.worker.events_shard_config
+        self._instance_name = hs.get_instance_name()
+
+    @staticmethod
+    async def _serialize_payload(room_id: str) -> JsonDict:  # type: ignore[override]
+        return {}
+
+    async def _handle_request(  # type: ignore[override]
+        self, request: Request, room_id: str
+    ) -> Tuple[int, JsonDict]:
+        writer_instance = self._events_shard_config.get_instance(room_id)
+        if writer_instance != self._instance_name:
+            raise SynapseError(
+                400, "/update_current_state request was routed to the wrong worker"
+            )
+
+        await self._state_handler.update_current_state(room_id)
+
+        return 200, {}
+
+
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
+    if hs.get_instance_name() in hs.config.worker.writers.events:
+        ReplicationUpdateCurrentStateRestServlet(hs).register(http_server)
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
deleted file mode 100644
index 7644146dba..0000000000
--- a/synapse/replication/slave/storage/_base.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# Copyright 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-from typing import TYPE_CHECKING, Optional
-
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
-from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
-from synapse.storage.engines import PostgresEngine
-from synapse.storage.util.id_generators import MultiWriterIdGenerator
-
-if TYPE_CHECKING:
-    from synapse.server import HomeServer
-
-logger = logging.getLogger(__name__)
-
-
-class BaseSlavedStore(CacheInvalidationWorkerStore):
-    def __init__(
-        self,
-        database: DatabasePool,
-        db_conn: LoggingDatabaseConnection,
-        hs: "HomeServer",
-    ):
-        super().__init__(database, db_conn, hs)
-        if isinstance(self.database_engine, PostgresEngine):
-            self._cache_id_gen: Optional[
-                MultiWriterIdGenerator
-            ] = MultiWriterIdGenerator(
-                db_conn,
-                database,
-                stream_name="caches",
-                instance_name=hs.get_instance_name(),
-                tables=[
-                    (
-                        "cache_invalidation_stream_by_instance",
-                        "instance_name",
-                        "stream_id",
-                    )
-                ],
-                sequence_name="cache_invalidation_stream_seq",
-                writers=[],
-            )
-        else:
-            self._cache_id_gen = None
-
-        self.hs = hs
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
deleted file mode 100644
index ee74ee7d85..0000000000
--- a/synapse/replication/slave/storage/account_data.py
+++ /dev/null
@@ -1,22 +0,0 @@
-# Copyright 2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.storage.databases.main.account_data import AccountDataWorkerStore
-from synapse.storage.databases.main.tags import TagsWorkerStore
-
-
-class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
-    pass
diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py
deleted file mode 100644
index 29f50c0add..0000000000
--- a/synapse/replication/slave/storage/appservice.py
+++ /dev/null
@@ -1,25 +0,0 @@
-# Copyright 2015, 2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.storage.databases.main.appservice import (
-    ApplicationServiceTransactionWorkerStore,
-    ApplicationServiceWorkerStore,
-)
-
-
-class SlavedApplicationServiceStore(
-    ApplicationServiceTransactionWorkerStore, ApplicationServiceWorkerStore
-):
-    pass
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
deleted file mode 100644
index e940751084..0000000000
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# Copyright 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
-
-
-class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
-    pass
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 30717c2bd0..6fcade510a 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -14,18 +14,16 @@
 
 from typing import TYPE_CHECKING, Any, Iterable
 
-from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
 from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.devices import DeviceWorkerStore
-from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
 
 
-class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
+class SlavedDeviceStore(DeviceWorkerStore):
     def __init__(
         self,
         database: DatabasePool,
diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py
deleted file mode 100644
index 71fde0c96c..0000000000
--- a/synapse/replication/slave/storage/directory.py
+++ /dev/null
@@ -1,21 +0,0 @@
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.storage.databases.main.directory import DirectoryWorkerStore
-
-from ._base import BaseSlavedStore
-
-
-class DirectoryStore(DirectoryWorkerStore, BaseSlavedStore):
-    pass
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index a72dad7464..fe47778cb1 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -29,8 +29,6 @@ from synapse.storage.databases.main.stream import StreamWorkerStore
 from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
-from ._base import BaseSlavedStore
-
 if TYPE_CHECKING:
     from synapse.server import HomeServer
 
@@ -56,7 +54,6 @@ class SlavedEventStore(
     EventsWorkerStore,
     UserErasureWorkerStore,
     RelationsWorkerStore,
-    BaseSlavedStore,
 ):
     def __init__(
         self,
diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py
index 4d185e2b56..c52679cd60 100644
--- a/synapse/replication/slave/storage/filtering.py
+++ b/synapse/replication/slave/storage/filtering.py
@@ -14,16 +14,15 @@
 
 from typing import TYPE_CHECKING
 
+from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.filtering import FilteringStore
 
-from ._base import BaseSlavedStore
-
 if TYPE_CHECKING:
     from synapse.server import HomeServer
 
 
-class SlavedFilteringStore(BaseSlavedStore):
+class SlavedFilteringStore(SQLBaseStore):
     def __init__(
         self,
         database: DatabasePool,
diff --git a/synapse/replication/slave/storage/profile.py b/synapse/replication/slave/storage/profile.py
deleted file mode 100644
index 99f4a22642..0000000000
--- a/synapse/replication/slave/storage/profile.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.storage.databases.main.profile import ProfileWorkerStore
-
-
-class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore):
-    pass
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 52ee3f7e58..5e65eaf1e0 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -31,6 +31,5 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
             self._push_rules_stream_id_gen.advance(instance_name, token)
             for row in rows:
                 self.get_push_rules_for_user.invalidate((row.user_id,))
-                self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
                 self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
         return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index de642bba71..44ed20e424 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -18,14 +18,13 @@ from synapse.replication.tcp.streams import PushersStream
 from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.pusher import PusherWorkerStore
 
-from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
 
 
-class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
+class SlavedPusherStore(PusherWorkerStore):
     def __init__(
         self,
         database: DatabasePool,
diff --git a/synapse/replication/slave/storage/registration.py b/synapse/replication/slave/storage/registration.py
deleted file mode 100644
index 5dae35a960..0000000000
--- a/synapse/replication/slave/storage/registration.py
+++ /dev/null
@@ -1,21 +0,0 @@
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.storage.databases.main.registration import RegistrationWorkerStore
-
-from ._base import BaseSlavedStore
-
-
-class SlavedRegistrationStore(RegistrationWorkerStore, BaseSlavedStore):
-    pass
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 2f59245058..e4f2201c92 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -21,7 +21,7 @@ from twisted.internet.interfaces import IAddress, IConnector
 from twisted.internet.protocol import ReconnectingClientFactory
 from twisted.python.failure import Failure
 
-from synapse.api.constants import EventTypes, ReceiptTypes
+from synapse.api.constants import EventTypes, Membership, ReceiptTypes
 from synapse.federation import send_queue
 from synapse.federation.sender import FederationSender
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
@@ -219,6 +219,21 @@ class ReplicationDataHandler:
                     membership=row.data.membership,
                 )
 
+                # If this event is a join, make a note of it so we have an accurate
+                # cross-worker room rate limit.
+                # TODO: Erik said we should exclude rows that came from ex_outliers
+                #  here, but I don't see how we can determine that. I guess we could
+                #  add a flag to row.data?
+                if (
+                    row.data.type == EventTypes.Member
+                    and row.data.membership == Membership.JOIN
+                    and not row.data.outlier
+                ):
+                    # TODO retrieve the previous state, and exclude join -> join transitions
+                    self.notifier.notify_user_joined_room(
+                        row.data.event_id, row.data.room_id
+                    )
+
         await self._presence_handler.process_replication_rows(
             stream_name, instance_name, token, rows
         )
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 26f4fa7cfd..14b6705862 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -98,6 +98,7 @@ class EventsStreamEventRow(BaseEventsStreamRow):
     relates_to: Optional[str]
     membership: Optional[str]
     rejected: bool
+    outlier: bool
 
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
diff --git a/synapse/res/templates/account_previously_renewed.html b/synapse/res/templates/account_previously_renewed.html
index b751359bdf..bd4f7cea97 100644
--- a/synapse/res/templates/account_previously_renewed.html
+++ b/synapse/res/templates/account_previously_renewed.html
@@ -1 +1,12 @@
-<html><body>Your account is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.</body><html>
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>Your account is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.</title>
+</head>
+<body>
+    Your account is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.
+</body>
+</html>
\ No newline at end of file
diff --git a/synapse/res/templates/account_renewed.html b/synapse/res/templates/account_renewed.html
index e8c0f52f05..57b319f375 100644
--- a/synapse/res/templates/account_renewed.html
+++ b/synapse/res/templates/account_renewed.html
@@ -1 +1,12 @@
-<html><body>Your account has been successfully renewed and is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.</body><html>
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>Your account has been successfully renewed and is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.</title>
+</head>
+<body>
+    Your account has been successfully renewed and is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.
+</body>
+</html>
\ No newline at end of file
diff --git a/synapse/res/templates/add_threepid.html b/synapse/res/templates/add_threepid.html
index cc4ab07e09..71f2215b7a 100644
--- a/synapse/res/templates/add_threepid.html
+++ b/synapse/res/templates/add_threepid.html
@@ -1,9 +1,14 @@
-<html>
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>Request to add an email address to your Matrix account</title>
+</head>
 <body>
     <p>A request to add an email address to your Matrix account has been received. If this was you, please click the link below to confirm adding this email:</p>
-
     <a href="{{ link }}">{{ link }}</a>
-
     <p>If this was not you, you can safely ignore this email. Thank you.</p>
 </body>
 </html>
diff --git a/synapse/res/templates/add_threepid_failure.html b/synapse/res/templates/add_threepid_failure.html
index 441d11c846..bd627ee9ce 100644
--- a/synapse/res/templates/add_threepid_failure.html
+++ b/synapse/res/templates/add_threepid_failure.html
@@ -1,8 +1,13 @@
-<html>
-<head></head>
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>Request failed</title>
+</head>
 <body>
-<p>The request failed for the following reason: {{ failure_reason }}.</p>
-
-<p>No changes have been made to your account.</p>
+    <p>The request failed for the following reason: {{ failure_reason }}.</p>
+    <p>No changes have been made to your account.</p>
 </body>
 </html>
diff --git a/synapse/res/templates/add_threepid_success.html b/synapse/res/templates/add_threepid_success.html
index fbd6e4018f..49170c138e 100644
--- a/synapse/res/templates/add_threepid_success.html
+++ b/synapse/res/templates/add_threepid_success.html
@@ -1,6 +1,12 @@
-<html>
-<head></head>
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>Your email has now been validated</title>
+</head>
 <body>
-<p>Your email has now been validated, please return to your client. You may now close this window.</p>
+    <p>Your email has now been validated, please return to your client. You may now close this window.</p>
 </body>
-</html>
+</html>
\ No newline at end of file
diff --git a/synapse/res/templates/auth_success.html b/synapse/res/templates/auth_success.html
index baf4633142..2d6ac44a0e 100644
--- a/synapse/res/templates/auth_success.html
+++ b/synapse/res/templates/auth_success.html
@@ -1,8 +1,8 @@
 <html>
 <head>
 <title>Success!</title>
-<meta name='viewport' content='width=device-width, initial-scale=1,
-    user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
 <link rel="stylesheet" href="/_matrix/static/client/register/style.css">
 <script>
 if (window.onAuthDone) {
diff --git a/synapse/res/templates/invalid_token.html b/synapse/res/templates/invalid_token.html
index 6bd2b98364..2c7c384fe3 100644
--- a/synapse/res/templates/invalid_token.html
+++ b/synapse/res/templates/invalid_token.html
@@ -1 +1,12 @@
-<html><body>Invalid renewal token.</body><html>
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>Invalid renewal token.</title>
+</head>
+<body>
+    Invalid renewal token.
+</body>
+</html>
diff --git a/synapse/res/templates/notice_expiry.html b/synapse/res/templates/notice_expiry.html
index d87311f659..865f9f7ada 100644
--- a/synapse/res/templates/notice_expiry.html
+++ b/synapse/res/templates/notice_expiry.html
@@ -1,6 +1,8 @@
 <!doctype html>
 <html lang="en">
     <head>
+        <meta http-equiv="X-UA-Compatible" content="IE=edge">
+        <meta name="viewport" content="width=device-width, initial-scale=1.0">
         <style type="text/css">
             {% include 'mail.css' without context %}
             {% include "mail-%s.css" % app_name ignore missing without context %}
diff --git a/synapse/res/templates/notif_mail.html b/synapse/res/templates/notif_mail.html
index 27d4182790..9dba0c0253 100644
--- a/synapse/res/templates/notif_mail.html
+++ b/synapse/res/templates/notif_mail.html
@@ -1,6 +1,8 @@
 <!doctype html>
 <html lang="en">
     <head>
+        <meta http-equiv="X-UA-Compatible" content="IE=edge">
+        <meta name="viewport" content="width=device-width, initial-scale=1.0">
         <style type="text/css">
             {%- include 'mail.css' without context %}
             {%- include "mail-%s.css" % app_name ignore missing without context %}
diff --git a/synapse/res/templates/password_reset.html b/synapse/res/templates/password_reset.html
index a197bf872c..a8bdce357b 100644
--- a/synapse/res/templates/password_reset.html
+++ b/synapse/res/templates/password_reset.html
@@ -1,4 +1,9 @@
-<html>
+<html lang="en">
+    <head>
+        <title>Password reset</title>
+        <meta http-equiv="X-UA-Compatible" content="IE=edge">
+        <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    </head>
 <body>
     <p>A password reset request has been received for your Matrix account. If this was you, please click the link below to confirm resetting your password:</p>
 
diff --git a/synapse/res/templates/password_reset_confirmation.html b/synapse/res/templates/password_reset_confirmation.html
index def4b5162b..2e3fd2ec1e 100644
--- a/synapse/res/templates/password_reset_confirmation.html
+++ b/synapse/res/templates/password_reset_confirmation.html
@@ -1,5 +1,9 @@
-<html>
-<head></head>
+<html lang="en">
+<head>
+    <title>Password reset confirmation</title>
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+</head>
 <body>
 <!--Use a hidden form to resubmit the information necessary to reset the password-->
 <form method="post">
diff --git a/synapse/res/templates/password_reset_failure.html b/synapse/res/templates/password_reset_failure.html
index 9e3c4446e3..2d59c463f0 100644
--- a/synapse/res/templates/password_reset_failure.html
+++ b/synapse/res/templates/password_reset_failure.html
@@ -1,5 +1,9 @@
-<html>
-<head></head>
+<html lang="en">
+<head>
+    <title>Password reset failure</title>
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+</head>
 <body>
 <p>The request failed for the following reason: {{ failure_reason }}.</p>
 
diff --git a/synapse/res/templates/password_reset_success.html b/synapse/res/templates/password_reset_success.html
index 7324d66d1e..5165bd1fa2 100644
--- a/synapse/res/templates/password_reset_success.html
+++ b/synapse/res/templates/password_reset_success.html
@@ -1,5 +1,8 @@
-<html>
-<head></head>
+<html lang="en">
+<head>
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+</head>
 <body>
 <p>Your email has now been validated, please return to your client to reset your password. You may now close this window.</p>
 </body>
diff --git a/synapse/res/templates/recaptcha.html b/synapse/res/templates/recaptcha.html
index b3db06ef97..615d3239c6 100644
--- a/synapse/res/templates/recaptcha.html
+++ b/synapse/res/templates/recaptcha.html
@@ -1,8 +1,8 @@
 <html>
 <head>
 <title>Authentication</title>
-<meta name='viewport' content='width=device-width, initial-scale=1,
-    user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
 <script src="https://www.recaptcha.net/recaptcha/api.js"
     async defer></script>
 <script src="//code.jquery.com/jquery-1.11.2.min.js"></script>
diff --git a/synapse/res/templates/registration.html b/synapse/res/templates/registration.html
index 16730a527f..20e831ff4a 100644
--- a/synapse/res/templates/registration.html
+++ b/synapse/res/templates/registration.html
@@ -1,4 +1,9 @@
-<html>
+<html lang="en">
+<head>
+    <title>Registration</title>
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+</head>
 <body>
     <p>You have asked us to register this email with a new Matrix account. If this was you, please click the link below to confirm your email address:</p>
 
diff --git a/synapse/res/templates/registration_failure.html b/synapse/res/templates/registration_failure.html
index 2833d79c37..a6ed22bc90 100644
--- a/synapse/res/templates/registration_failure.html
+++ b/synapse/res/templates/registration_failure.html
@@ -1,5 +1,8 @@
-<html>
-<head></head>
+<html lang="en">
+<head>
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+</head>
 <body>
 <p>Validation failed for the following reason: {{ failure_reason }}.</p>
 </body>
diff --git a/synapse/res/templates/registration_success.html b/synapse/res/templates/registration_success.html
index fbd6e4018f..d51d5549d8 100644
--- a/synapse/res/templates/registration_success.html
+++ b/synapse/res/templates/registration_success.html
@@ -1,5 +1,9 @@
-<html>
-<head></head>
+<html lang="en">
+<head>
+    <title>Your email has now been validated</title>
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+</head>
 <body>
 <p>Your email has now been validated, please return to your client. You may now close this window.</p>
 </body>
diff --git a/synapse/res/templates/registration_token.html b/synapse/res/templates/registration_token.html
index 4577ce1702..59a98f564c 100644
--- a/synapse/res/templates/registration_token.html
+++ b/synapse/res/templates/registration_token.html
@@ -1,8 +1,8 @@
-<html>
+<html lang="en">
 <head>
 <title>Authentication</title>
-<meta name='viewport' content='width=device-width, initial-scale=1,
-    user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
 <link rel="stylesheet" href="/_matrix/static/client/register/style.css">
 </head>
 <body>
diff --git a/synapse/res/templates/sso_account_deactivated.html b/synapse/res/templates/sso_account_deactivated.html
index c3e4deed93..075f801cec 100644
--- a/synapse/res/templates/sso_account_deactivated.html
+++ b/synapse/res/templates/sso_account_deactivated.html
@@ -3,8 +3,8 @@
     <head>
         <meta charset="UTF-8">
         <title>SSO account deactivated</title>
-        <meta name="viewport" content="width=device-width, user-scalable=no">
-        <style type="text/css">
+        <meta http-equiv="X-UA-Compatible" content="IE=edge">
+        <meta name="viewport" content="width=device-width, initial-scale=1.0">        <style type="text/css">
             {% include "sso.css" without context %}
         </style>
     </head>
diff --git a/synapse/res/templates/sso_auth_account_details.html b/synapse/res/templates/sso_auth_account_details.html
index 1ba850369a..2d1db386e1 100644
--- a/synapse/res/templates/sso_auth_account_details.html
+++ b/synapse/res/templates/sso_auth_account_details.html
@@ -3,7 +3,8 @@
   <head>
     <title>Create your account</title>
     <meta charset="utf-8">
-    <meta name="viewport" content="width=device-width, user-scalable=no">
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
     <script type="text/javascript">
       let wasKeyboard = false;
       document.addEventListener("mousedown", function() { wasKeyboard = false; });
@@ -138,7 +139,7 @@
         <div class="username_input" id="username_input">
           <label for="field-username">Username (required)</label>
           <div class="prefix">@</div>
-          <input type="text" name="username" id="field-username" value="{{ user_attributes.localpart }}" autofocus>
+          <input type="text" name="username" id="field-username" value="{{ user_attributes.localpart }}" autofocus autocorrect="off" autocapitalize="none">
           <div class="postfix">:{{ server_name }}</div>
         </div>
         <output for="username_input" id="field-username-output"></output>
diff --git a/synapse/res/templates/sso_auth_bad_user.html b/synapse/res/templates/sso_auth_bad_user.html
index da579ffe69..94403fc3ce 100644
--- a/synapse/res/templates/sso_auth_bad_user.html
+++ b/synapse/res/templates/sso_auth_bad_user.html
@@ -3,7 +3,8 @@
     <head>
         <meta charset="UTF-8">
         <title>Authentication failed</title>
-        <meta name="viewport" content="width=device-width, user-scalable=no">
+        <meta http-equiv="X-UA-Compatible" content="IE=edge">
+        <meta name="viewport" content="width=device-width, initial-scale=1.0">
         <style type="text/css">
             {% include "sso.css" without context %}
         </style>
diff --git a/synapse/res/templates/sso_auth_confirm.html b/synapse/res/templates/sso_auth_confirm.html
index f9d0456f0a..aa1c974a6b 100644
--- a/synapse/res/templates/sso_auth_confirm.html
+++ b/synapse/res/templates/sso_auth_confirm.html
@@ -3,7 +3,8 @@
     <head>
         <meta charset="UTF-8">
         <title>Confirm it's you</title>
-        <meta name="viewport" content="width=device-width, user-scalable=no">
+        <meta http-equiv="X-UA-Compatible" content="IE=edge">
+        <meta name="viewport" content="width=device-width, initial-scale=1.0">
         <style type="text/css">
             {% include "sso.css" without context %}
         </style>
diff --git a/synapse/res/templates/sso_auth_success.html b/synapse/res/templates/sso_auth_success.html
index 1ed3967e87..4898af6011 100644
--- a/synapse/res/templates/sso_auth_success.html
+++ b/synapse/res/templates/sso_auth_success.html
@@ -3,7 +3,8 @@
     <head>
         <meta charset="UTF-8">
         <title>Authentication successful</title>
-        <meta name="viewport" content="width=device-width, user-scalable=no">
+        <meta http-equiv="X-UA-Compatible" content="IE=edge">
+        <meta name="viewport" content="width=device-width, initial-scale=1.0">
         <style type="text/css">
             {% include "sso.css" without context %}
         </style>
diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html
index 472309c350..19992ff2ad 100644
--- a/synapse/res/templates/sso_error.html
+++ b/synapse/res/templates/sso_error.html
@@ -3,7 +3,8 @@
     <head>
         <meta charset="UTF-8">
         <title>Authentication failed</title>
-        <meta name="viewport" content="width=device-width, user-scalable=no">
+        <meta http-equiv="X-UA-Compatible" content="IE=edge">
+        <meta name="viewport" content="width=device-width, initial-scale=1.0">
         <style type="text/css">
             {% include "sso.css" without context %}
 
diff --git a/synapse/res/templates/sso_login_idp_picker.html b/synapse/res/templates/sso_login_idp_picker.html
index 53b82db84e..56fabfa3d2 100644
--- a/synapse/res/templates/sso_login_idp_picker.html
+++ b/synapse/res/templates/sso_login_idp_picker.html
@@ -1,6 +1,8 @@
 <!DOCTYPE html>
 <html lang="en">
     <head>
+        <meta http-equiv="X-UA-Compatible" content="IE=edge">
+        <meta name="viewport" content="width=device-width, initial-scale=1.0">
         <meta charset="UTF-8">
         <title>Choose identity provider</title>
         <style type="text/css">
diff --git a/synapse/res/templates/sso_new_user_consent.html b/synapse/res/templates/sso_new_user_consent.html
index 68c8b9f33a..523f64c4fc 100644
--- a/synapse/res/templates/sso_new_user_consent.html
+++ b/synapse/res/templates/sso_new_user_consent.html
@@ -3,7 +3,8 @@
 <head>
     <meta charset="UTF-8">
     <title>Agree to terms and conditions</title>
-    <meta name="viewport" content="width=device-width, user-scalable=no">
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
     <style type="text/css">
       {% include "sso.css" without context %}
 
diff --git a/synapse/res/templates/sso_redirect_confirm.html b/synapse/res/templates/sso_redirect_confirm.html
index 1b01471ac8..1049a9bd92 100644
--- a/synapse/res/templates/sso_redirect_confirm.html
+++ b/synapse/res/templates/sso_redirect_confirm.html
@@ -3,7 +3,8 @@
 <head>
     <meta charset="UTF-8">
     <title>Continue to your account</title>
-    <meta name="viewport" content="width=device-width, user-scalable=no">
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
     <style type="text/css">
       {% include "sso.css" without context %}
 
diff --git a/synapse/res/templates/terms.html b/synapse/res/templates/terms.html
index 369ff446d2..2081d990ab 100644
--- a/synapse/res/templates/terms.html
+++ b/synapse/res/templates/terms.html
@@ -1,8 +1,8 @@
 <html>
 <head>
 <title>Authentication</title>
-<meta name='viewport' content='width=device-width, initial-scale=1,
-    user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
+<meta http-equiv="X-UA-Compatible" content="IE=edge">
+<meta name="viewport" content="width=device-width, initial-scale=1.0">
 <link rel="stylesheet" href="/_matrix/static/client/register/style.css">
 </head>
 <body>
diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py
index 399b205aaf..b467a61dfb 100644
--- a/synapse/rest/admin/_base.py
+++ b/synapse/rest/admin/_base.py
@@ -19,7 +19,7 @@ from typing import Iterable, Pattern
 from synapse.api.auth import Auth
 from synapse.api.errors import AuthError
 from synapse.http.site import SynapseRequest
-from synapse.types import UserID
+from synapse.types import Requester
 
 
 def admin_patterns(path_regex: str, version: str = "v1") -> Iterable[Pattern]:
@@ -48,19 +48,19 @@ async def assert_requester_is_admin(auth: Auth, request: SynapseRequest) -> None
         AuthError if the requester is not a server admin
     """
     requester = await auth.get_user_by_req(request)
-    await assert_user_is_admin(auth, requester.user)
+    await assert_user_is_admin(auth, requester)
 
 
-async def assert_user_is_admin(auth: Auth, user_id: UserID) -> None:
+async def assert_user_is_admin(auth: Auth, requester: Requester) -> None:
     """Verify that the given user is an admin user
 
     Args:
         auth: Auth singleton
-        user_id: user to check
+        requester: The user making the request, according to the access token.
 
     Raises:
         AuthError if the user is not a server admin
     """
-    is_admin = await auth.is_server_admin(user_id)
+    is_admin = await auth.is_server_admin(requester)
     if not is_admin:
         raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 19d4a008e8..73470f09ae 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -54,7 +54,7 @@ class QuarantineMediaInRoom(RestServlet):
         self, request: SynapseRequest, room_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_user_is_admin(self.auth, requester)
 
         logging.info("Quarantining room: %s", room_id)
 
@@ -81,7 +81,7 @@ class QuarantineMediaByUser(RestServlet):
         self, request: SynapseRequest, user_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_user_is_admin(self.auth, requester)
 
         logging.info("Quarantining media by user: %s", user_id)
 
@@ -110,7 +110,7 @@ class QuarantineMediaByID(RestServlet):
         self, request: SynapseRequest, server_name: str, media_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_user_is_admin(self.auth, requester)
 
         logging.info("Quarantining media by ID: %s/%s", server_name, media_id)
 
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 9d953d58de..3d870629c4 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -75,7 +75,7 @@ class RoomRestV2Servlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
 
         requester = await self._auth.get_user_by_req(request)
-        await assert_user_is_admin(self._auth, requester.user)
+        await assert_user_is_admin(self._auth, requester)
 
         content = parse_json_object_from_request(request)
 
@@ -303,6 +303,7 @@ class RoomRestServlet(RestServlet):
 
         members = await self.store.get_users_in_room(room_id)
         ret["joined_local_devices"] = await self.store.count_devices_by_users(members)
+        ret["forgotten"] = await self.store.is_locally_forgotten_room(room_id)
 
         return HTTPStatus.OK, ret
 
@@ -326,7 +327,7 @@ class RoomRestServlet(RestServlet):
         pagination_handler: "PaginationHandler",
     ) -> Tuple[int, JsonDict]:
         requester = await auth.get_user_by_req(request)
-        await assert_user_is_admin(auth, requester.user)
+        await assert_user_is_admin(auth, requester)
 
         content = parse_json_object_from_request(request)
 
@@ -460,7 +461,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
         assert request.args is not None
 
         requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_user_is_admin(self.auth, requester)
 
         content = parse_json_object_from_request(request)
 
@@ -550,7 +551,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
         self, request: SynapseRequest, room_identifier: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_user_is_admin(self.auth, requester)
         content = parse_json_object_from_request(request, allow_empty_body=True)
 
         room_id, _ = await self.resolve_room_id(room_identifier)
@@ -741,7 +742,7 @@ class RoomEventContextServlet(RestServlet):
         self, request: SynapseRequest, room_id: str, event_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=False)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_user_is_admin(self.auth, requester)
 
         limit = parse_integer(request, "limit", default=10)
 
@@ -833,7 +834,7 @@ class BlockRoomRestServlet(RestServlet):
         self, request: SynapseRequest, room_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self._auth.get_user_by_req(request)
-        await assert_user_is_admin(self._auth, requester.user)
+        await assert_user_is_admin(self._auth, requester)
 
         content = parse_json_object_from_request(request)
 
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index f0614a2897..78ee9b6532 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -183,7 +183,7 @@ class UserRestServletV2(RestServlet):
         self, request: SynapseRequest, user_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_user_is_admin(self.auth, requester)
 
         target_user = UserID.from_string(user_id)
         body = parse_json_object_from_request(request)
@@ -373,6 +373,7 @@ class UserRestServletV2(RestServlet):
                     if (
                         self.hs.config.email.email_enable_notifs
                         and self.hs.config.email.email_notif_for_new_users
+                        and medium == "email"
                     ):
                         await self.pusher_pool.add_pusher(
                             user_id=user_id,
@@ -574,10 +575,9 @@ class WhoisRestServlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
         target_user = UserID.from_string(user_id)
         requester = await self.auth.get_user_by_req(request)
-        auth_user = requester.user
 
-        if target_user != auth_user:
-            await assert_user_is_admin(self.auth, auth_user)
+        if target_user != requester.user:
+            await assert_user_is_admin(self.auth, requester)
 
         if not self.is_mine(target_user):
             raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only whois a local user")
@@ -600,7 +600,7 @@ class DeactivateAccountRestServlet(RestServlet):
         self, request: SynapseRequest, target_user_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_user_is_admin(self.auth, requester)
 
         if not self.is_mine(UserID.from_string(target_user_id)):
             raise SynapseError(
@@ -692,7 +692,7 @@ class ResetPasswordRestServlet(RestServlet):
         This needs user to have administrator access in Synapse.
         """
         requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_user_is_admin(self.auth, requester)
 
         UserID.from_string(target_user_id)
 
@@ -806,7 +806,7 @@ class UserAdminServlet(RestServlet):
         self, request: SynapseRequest, user_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_user_is_admin(self.auth, requester)
         auth_user = requester.user
 
         target_user = UserID.from_string(user_id)
@@ -920,7 +920,7 @@ class UserTokenRestServlet(RestServlet):
         self, request: SynapseRequest, user_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_user_is_admin(self.auth, requester)
         auth_user = requester.user
 
         if not self.is_mine_id(user_id):
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index bdc4a9c068..1f9a8ccc23 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -15,10 +15,11 @@
 # limitations under the License.
 import logging
 import random
-from http import HTTPStatus
 from typing import TYPE_CHECKING, Optional, Tuple
 from urllib.parse import urlparse
 
+from pydantic import StrictBool, StrictStr, constr
+
 from twisted.web.server import Request
 
 from synapse.api.constants import LoginType
@@ -28,18 +29,20 @@ from synapse.api.errors import (
     SynapseError,
     ThreepidValidationError,
 )
-from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http.server import HttpServer, finish_request, respond_with_html
 from synapse.http.servlet import (
     RestServlet,
     assert_params_in_dict,
+    parse_and_validate_json_object_from_request,
     parse_json_object_from_request,
     parse_string,
 )
 from synapse.http.site import SynapseRequest
 from synapse.metrics import threepid_send_requests
 from synapse.push.mailer import Mailer
+from synapse.rest.client.models import AuthenticationData, EmailRequestTokenBody
+from synapse.rest.models import RequestBodyModel
 from synapse.types import JsonDict
 from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.stringutils import assert_valid_client_secret, random_string
@@ -64,7 +67,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
         self.config = hs.config
         self.identity_handler = hs.get_identity_handler()
 
-        if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+        if self.config.email.can_verify_email:
             self.mailer = Mailer(
                 hs=self.hs,
                 app_name=self.config.email.email_app_name,
@@ -73,41 +76,24 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
             )
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
-            if self.config.email.local_threepid_handling_disabled_due_to_email_config:
-                logger.warning(
-                    "User password resets have been disabled due to lack of email config"
-                )
+        if not self.config.email.can_verify_email:
+            logger.warning(
+                "User password resets have been disabled due to lack of email config"
+            )
             raise SynapseError(
                 400, "Email-based password resets have been disabled on this server"
             )
 
-        body = parse_json_object_from_request(request)
-
-        assert_params_in_dict(body, ["client_secret", "email", "send_attempt"])
-
-        # Extract params from body
-        client_secret = body["client_secret"]
-        assert_valid_client_secret(client_secret)
-
-        # Canonicalise the email address. The addresses are all stored canonicalised
-        # in the database. This allows the user to reset his password without having to
-        # know the exact spelling (eg. upper and lower case) of address in the database.
-        # Stored in the database "foo@bar.com"
-        # User requests with "FOO@bar.com" would raise a Not Found error
-        try:
-            email = validate_email(body["email"])
-        except ValueError as e:
-            raise SynapseError(400, str(e))
-        send_attempt = body["send_attempt"]
-        next_link = body.get("next_link")  # Optional param
+        body = parse_and_validate_json_object_from_request(
+            request, EmailRequestTokenBody
+        )
 
-        if next_link:
+        if body.next_link:
             # Raise if the provided next_link value isn't valid
-            assert_valid_next_link(self.hs, next_link)
+            assert_valid_next_link(self.hs, body.next_link)
 
         await self.identity_handler.ratelimit_request_token_requests(
-            request, "email", email
+            request, "email", body.email
         )
 
         # The email will be sent to the stored address.
@@ -115,7 +101,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
         # an email address which is controlled by the attacker but which, after
         # canonicalisation, matches the one in our database.
         existing_user_id = await self.hs.get_datastores().main.get_user_id_by_threepid(
-            "email", email
+            "email", body.email
         )
 
         if existing_user_id is None:
@@ -129,35 +115,20 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
 
             raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
 
-        if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
-            assert self.hs.config.registration.account_threepid_delegate_email
-
-            # Have the configured identity server handle the request
-            ret = await self.identity_handler.requestEmailToken(
-                self.hs.config.registration.account_threepid_delegate_email,
-                email,
-                client_secret,
-                send_attempt,
-                next_link,
-            )
-        else:
-            # Send password reset emails from Synapse
-            sid = await self.identity_handler.send_threepid_validation(
-                email,
-                client_secret,
-                send_attempt,
-                self.mailer.send_password_reset_mail,
-                next_link,
-            )
-
-            # Wrap the session id in a JSON object
-            ret = {"sid": sid}
-
+        # Send password reset emails from Synapse
+        sid = await self.identity_handler.send_threepid_validation(
+            body.email,
+            body.client_secret,
+            body.send_attempt,
+            self.mailer.send_password_reset_mail,
+            body.next_link,
+        )
         threepid_send_requests.labels(type="email", reason="password_reset").observe(
-            send_attempt
+            body.send_attempt
         )
 
-        return 200, ret
+        # Wrap the session id in a JSON object
+        return 200, {"sid": sid}
 
 
 class PasswordRestServlet(RestServlet):
@@ -172,16 +143,23 @@ class PasswordRestServlet(RestServlet):
         self.password_policy_handler = hs.get_password_policy_handler()
         self._set_password_handler = hs.get_set_password_handler()
 
+    class PostBody(RequestBodyModel):
+        auth: Optional[AuthenticationData] = None
+        logout_devices: StrictBool = True
+        if TYPE_CHECKING:
+            # workaround for https://github.com/samuelcolvin/pydantic/issues/156
+            new_password: Optional[StrictStr] = None
+        else:
+            new_password: Optional[constr(max_length=512, strict=True)] = None
+
     @interactive_auth_handler
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        body = parse_json_object_from_request(request)
+        body = parse_and_validate_json_object_from_request(request, self.PostBody)
 
         # we do basic sanity checks here because the auth layer will store these
         # in sessions. Pull out the new password provided to us.
-        new_password = body.pop("new_password", None)
+        new_password = body.new_password
         if new_password is not None:
-            if not isinstance(new_password, str) or len(new_password) > 512:
-                raise SynapseError(400, "Invalid password")
             self.password_policy_handler.validate_password(new_password)
 
         # there are two possibilities here. Either the user does not have an
@@ -201,7 +179,7 @@ class PasswordRestServlet(RestServlet):
                 params, session_id = await self.auth_handler.validate_user_via_ui_auth(
                     requester,
                     request,
-                    body,
+                    body.dict(exclude_unset=True),
                     "modify your account password",
                 )
             except InteractiveAuthIncompleteError as e:
@@ -224,7 +202,7 @@ class PasswordRestServlet(RestServlet):
                 result, params, session_id = await self.auth_handler.check_ui_auth(
                     [[LoginType.EMAIL_IDENTITY]],
                     request,
-                    body,
+                    body.dict(exclude_unset=True),
                     "modify your account password",
                 )
             except InteractiveAuthIncompleteError as e:
@@ -299,37 +277,33 @@ class DeactivateAccountRestServlet(RestServlet):
         self.auth_handler = hs.get_auth_handler()
         self._deactivate_account_handler = hs.get_deactivate_account_handler()
 
+    class PostBody(RequestBodyModel):
+        auth: Optional[AuthenticationData] = None
+        id_server: Optional[StrictStr] = None
+        # Not specced, see https://github.com/matrix-org/matrix-spec/issues/297
+        erase: StrictBool = False
+
     @interactive_auth_handler
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        body = parse_json_object_from_request(request)
-        erase = body.get("erase", False)
-        if not isinstance(erase, bool):
-            raise SynapseError(
-                HTTPStatus.BAD_REQUEST,
-                "Param 'erase' must be a boolean, if given",
-                Codes.BAD_JSON,
-            )
+        body = parse_and_validate_json_object_from_request(request, self.PostBody)
 
         requester = await self.auth.get_user_by_req(request)
 
         # allow ASes to deactivate their own users
         if requester.app_service:
             await self._deactivate_account_handler.deactivate_account(
-                requester.user.to_string(), erase, requester
+                requester.user.to_string(), body.erase, requester
             )
             return 200, {}
 
         await self.auth_handler.validate_user_via_ui_auth(
             requester,
             request,
-            body,
+            body.dict(exclude_unset=True),
             "deactivate your account",
         )
         result = await self._deactivate_account_handler.deactivate_account(
-            requester.user.to_string(),
-            erase,
-            requester,
-            id_server=body.get("id_server"),
+            requester.user.to_string(), body.erase, requester, id_server=body.id_server
         )
         if result:
             id_server_unbind_result = "success"
@@ -349,7 +323,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
         self.identity_handler = hs.get_identity_handler()
         self.store = self.hs.get_datastores().main
 
-        if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+        if self.config.email.can_verify_email:
             self.mailer = Mailer(
                 hs=self.hs,
                 app_name=self.config.email.email_app_name,
@@ -358,34 +332,20 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
             )
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
-            if self.config.email.local_threepid_handling_disabled_due_to_email_config:
-                logger.warning(
-                    "Adding emails have been disabled due to lack of an email config"
-                )
+        if not self.config.email.can_verify_email:
+            logger.warning(
+                "Adding emails have been disabled due to lack of an email config"
+            )
             raise SynapseError(
-                400, "Adding an email to your account is disabled on this server"
+                400,
+                "Adding an email to your account is disabled on this server",
             )
 
-        body = parse_json_object_from_request(request)
-        assert_params_in_dict(body, ["client_secret", "email", "send_attempt"])
-        client_secret = body["client_secret"]
-        assert_valid_client_secret(client_secret)
-
-        # Canonicalise the email address. The addresses are all stored canonicalised
-        # in the database.
-        # This ensures that the validation email is sent to the canonicalised address
-        # as it will later be entered into the database.
-        # Otherwise the email will be sent to "FOO@bar.com" and stored as
-        # "foo@bar.com" in database.
-        try:
-            email = validate_email(body["email"])
-        except ValueError as e:
-            raise SynapseError(400, str(e))
-        send_attempt = body["send_attempt"]
-        next_link = body.get("next_link")  # Optional param
+        body = parse_and_validate_json_object_from_request(
+            request, EmailRequestTokenBody
+        )
 
-        if not await check_3pid_allowed(self.hs, "email", email):
+        if not await check_3pid_allowed(self.hs, "email", body.email):
             raise SynapseError(
                 403,
                 "Your email domain is not authorized on this server",
@@ -393,14 +353,14 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
             )
 
         await self.identity_handler.ratelimit_request_token_requests(
-            request, "email", email
+            request, "email", body.email
         )
 
-        if next_link:
+        if body.next_link:
             # Raise if the provided next_link value isn't valid
-            assert_valid_next_link(self.hs, next_link)
+            assert_valid_next_link(self.hs, body.next_link)
 
-        existing_user_id = await self.store.get_user_id_by_threepid("email", email)
+        existing_user_id = await self.store.get_user_id_by_threepid("email", body.email)
 
         if existing_user_id is not None:
             if self.config.server.request_token_inhibit_3pid_errors:
@@ -413,35 +373,21 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
 
             raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
 
-        if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
-            assert self.hs.config.registration.account_threepid_delegate_email
-
-            # Have the configured identity server handle the request
-            ret = await self.identity_handler.requestEmailToken(
-                self.hs.config.registration.account_threepid_delegate_email,
-                email,
-                client_secret,
-                send_attempt,
-                next_link,
-            )
-        else:
-            # Send threepid validation emails from Synapse
-            sid = await self.identity_handler.send_threepid_validation(
-                email,
-                client_secret,
-                send_attempt,
-                self.mailer.send_add_threepid_mail,
-                next_link,
-            )
-
-            # Wrap the session id in a JSON object
-            ret = {"sid": sid}
+        # Send threepid validation emails from Synapse
+        sid = await self.identity_handler.send_threepid_validation(
+            body.email,
+            body.client_secret,
+            body.send_attempt,
+            self.mailer.send_add_threepid_mail,
+            body.next_link,
+        )
 
         threepid_send_requests.labels(type="email", reason="add_threepid").observe(
-            send_attempt
+            body.send_attempt
         )
 
-        return 200, ret
+        # Wrap the session id in a JSON object
+        return 200, {"sid": sid}
 
 
 class MsisdnThreepidRequestTokenRestServlet(RestServlet):
@@ -534,25 +480,18 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
         self.config = hs.config
         self.clock = hs.get_clock()
         self.store = hs.get_datastores().main
-        if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+        if self.config.email.can_verify_email:
             self._failure_email_template = (
                 self.config.email.email_add_threepid_template_failure_html
             )
 
     async def on_GET(self, request: Request) -> None:
-        if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
-            if self.config.email.local_threepid_handling_disabled_due_to_email_config:
-                logger.warning(
-                    "Adding emails have been disabled due to lack of an email config"
-                )
-            raise SynapseError(
-                400, "Adding an email to your account is disabled on this server"
+        if not self.config.email.can_verify_email:
+            logger.warning(
+                "Adding emails have been disabled due to lack of an email config"
             )
-        elif self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
             raise SynapseError(
-                400,
-                "This homeserver is not validating threepids. Use an identity server "
-                "instead.",
+                400, "Adding an email to your account is disabled on this server"
             )
 
         sid = parse_string(request, "sid", required=True)
@@ -743,10 +682,12 @@ class ThreepidBindRestServlet(RestServlet):
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         body = parse_json_object_from_request(request)
 
-        assert_params_in_dict(body, ["id_server", "sid", "client_secret"])
+        assert_params_in_dict(
+            body, ["id_server", "sid", "id_access_token", "client_secret"]
+        )
         id_server = body["id_server"]
         sid = body["sid"]
-        id_access_token = body.get("id_access_token")  # optional
+        id_access_token = body["id_access_token"]
         client_secret = body["client_secret"]
         assert_valid_client_secret(client_secret)
 
diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py
index bfe985939b..f13970b898 100644
--- a/synapse/rest/client/account_data.py
+++ b/synapse/rest/client/account_data.py
@@ -15,11 +15,11 @@
 import logging
 from typing import TYPE_CHECKING, Tuple
 
-from synapse.api.errors import AuthError, NotFoundError, SynapseError
+from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
 from synapse.http.site import SynapseRequest
-from synapse.types import JsonDict
+from synapse.types import JsonDict, RoomID
 
 from ._base import client_patterns
 
@@ -104,6 +104,13 @@ class RoomAccountDataServlet(RestServlet):
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add account data for other users.")
 
+        if not RoomID.is_valid(room_id):
+            raise SynapseError(
+                400,
+                f"{room_id} is not a valid room ID",
+                Codes.INVALID_PARAM,
+            )
+
         body = parse_json_object_from_request(request)
 
         if account_data_type == "m.fully_read":
@@ -111,6 +118,7 @@ class RoomAccountDataServlet(RestServlet):
                 405,
                 "Cannot set m.fully_read through this API."
                 " Use /rooms/!roomId:server.name/read_markers",
+                Codes.BAD_JSON,
             )
 
         await self.handler.add_account_data_to_room(
@@ -130,6 +138,13 @@ class RoomAccountDataServlet(RestServlet):
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot get account data for other users.")
 
+        if not RoomID.is_valid(room_id):
+            raise SynapseError(
+                400,
+                f"{room_id} is not a valid room ID",
+                Codes.INVALID_PARAM,
+            )
+
         event = await self.store.get_account_data_for_room_and_type(
             user_id, room_id, account_data_type
         )
diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index 6fab102437..ed6ce78d47 100644
--- a/synapse/rest/client/devices.py
+++ b/synapse/rest/client/devices.py
@@ -42,12 +42,26 @@ class DevicesRestServlet(RestServlet):
         self.hs = hs
         self.auth = hs.get_auth()
         self.device_handler = hs.get_device_handler()
+        self._msc3852_enabled = hs.config.experimental.msc3852_enabled
 
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         devices = await self.device_handler.get_devices_by_user(
             requester.user.to_string()
         )
+
+        # If MSC3852 is disabled, then the "last_seen_user_agent" field will be
+        # removed from each device. If it is enabled, then the field name will
+        # be replaced by the unstable identifier.
+        #
+        # When MSC3852 is accepted, this block of code can just be removed to
+        # expose "last_seen_user_agent" to clients.
+        for device in devices:
+            last_seen_user_agent = device["last_seen_user_agent"]
+            del device["last_seen_user_agent"]
+            if self._msc3852_enabled:
+                device["org.matrix.msc3852.last_seen_user_agent"] = last_seen_user_agent
+
         return 200, {"devices": devices}
 
 
@@ -108,6 +122,7 @@ class DeviceRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self.device_handler = hs.get_device_handler()
         self.auth_handler = hs.get_auth_handler()
+        self._msc3852_enabled = hs.config.experimental.msc3852_enabled
 
     async def on_GET(
         self, request: SynapseRequest, device_id: str
@@ -118,6 +133,18 @@ class DeviceRestServlet(RestServlet):
         )
         if device is None:
             raise NotFoundError("No device found")
+
+        # If MSC3852 is disabled, then the "last_seen_user_agent" field will be
+        # removed from each device. If it is enabled, then the field name will
+        # be replaced by the unstable identifier.
+        #
+        # When MSC3852 is accepted, this block of code can just be removed to
+        # expose "last_seen_user_agent" to clients.
+        last_seen_user_agent = device["last_seen_user_agent"]
+        del device["last_seen_user_agent"]
+        if self._msc3852_enabled:
+            device["org.matrix.msc3852.last_seen_user_agent"] = last_seen_user_agent
+
         return 200, device
 
     @interactive_auth_handler
diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py
index e181a0dde2..bc1b18c92d 100644
--- a/synapse/rest/client/directory.py
+++ b/synapse/rest/client/directory.py
@@ -17,13 +17,7 @@ from typing import TYPE_CHECKING, Tuple
 
 from twisted.web.server import Request
 
-from synapse.api.errors import (
-    AuthError,
-    Codes,
-    InvalidClientCredentialsError,
-    NotFoundError,
-    SynapseError,
-)
+from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
 from synapse.http.site import SynapseRequest
@@ -52,6 +46,8 @@ class ClientDirectoryServer(RestServlet):
         self.auth = hs.get_auth()
 
     async def on_GET(self, request: Request, room_alias: str) -> Tuple[int, JsonDict]:
+        if not RoomAlias.is_valid(room_alias):
+            raise SynapseError(400, "Room alias invalid", errcode=Codes.INVALID_PARAM)
         room_alias_obj = RoomAlias.from_string(room_alias)
 
         res = await self.directory_handler.get_association(room_alias_obj)
@@ -61,6 +57,8 @@ class ClientDirectoryServer(RestServlet):
     async def on_PUT(
         self, request: SynapseRequest, room_alias: str
     ) -> Tuple[int, JsonDict]:
+        if not RoomAlias.is_valid(room_alias):
+            raise SynapseError(400, "Room alias invalid", errcode=Codes.INVALID_PARAM)
         room_alias_obj = RoomAlias.from_string(room_alias)
 
         content = parse_json_object_from_request(request)
@@ -95,31 +93,30 @@ class ClientDirectoryServer(RestServlet):
     async def on_DELETE(
         self, request: SynapseRequest, room_alias: str
     ) -> Tuple[int, JsonDict]:
+        if not RoomAlias.is_valid(room_alias):
+            raise SynapseError(400, "Room alias invalid", errcode=Codes.INVALID_PARAM)
         room_alias_obj = RoomAlias.from_string(room_alias)
+        requester = await self.auth.get_user_by_req(request)
 
-        try:
-            service = self.auth.get_appservice_by_req(request)
+        if requester.app_service:
             await self.directory_handler.delete_appservice_association(
-                service, room_alias_obj
+                requester.app_service, room_alias_obj
             )
+
             logger.info(
                 "Application service at %s deleted alias %s",
-                service.url,
+                requester.app_service.url,
                 room_alias_obj.to_string(),
             )
-            return 200, {}
-        except InvalidClientCredentialsError:
-            # fallback to default user behaviour if they aren't an AS
-            pass
 
-        requester = await self.auth.get_user_by_req(request)
-        user = requester.user
+        else:
+            await self.directory_handler.delete_association(requester, room_alias_obj)
 
-        await self.directory_handler.delete_association(requester, room_alias_obj)
-
-        logger.info(
-            "User %s deleted alias %s", user.to_string(), room_alias_obj.to_string()
-        )
+            logger.info(
+                "User %s deleted alias %s",
+                requester.user.to_string(),
+                room_alias_obj.to_string(),
+            )
 
         return 200, {}
 
@@ -154,17 +151,6 @@ class ClientDirectoryListServer(RestServlet):
 
         return 200, {}
 
-    async def on_DELETE(
-        self, request: SynapseRequest, room_id: str
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-
-        await self.directory_handler.edit_published_room_list(
-            requester, room_id, "private"
-        )
-
-        return 200, {}
-
 
 class ClientAppserviceDirectoryListServer(RestServlet):
     PATTERNS = client_patterns(
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index ce806e3c11..a395694fa5 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -26,7 +26,7 @@ from synapse.http.servlet import (
     parse_string,
 )
 from synapse.http.site import SynapseRequest
-from synapse.logging.opentracing import log_kv, set_tag, trace
+from synapse.logging.opentracing import log_kv, set_tag
 from synapse.types import JsonDict, StreamToken
 
 from ._base import client_patterns, interactive_auth_handler
@@ -71,7 +71,6 @@ class KeyUploadServlet(RestServlet):
         self.e2e_keys_handler = hs.get_e2e_keys_handler()
         self.device_handler = hs.get_device_handler()
 
-    @trace(opname="upload_keys")
     async def on_POST(
         self, request: SynapseRequest, device_id: Optional[str]
     ) -> Tuple[int, JsonDict]:
@@ -208,7 +207,9 @@ class KeyChangesServlet(RestServlet):
 
         # We want to enforce they do pass us one, but we ignore it and return
         # changes after the "to" as well as before.
-        set_tag("to", parse_string(request, "to"))
+        #
+        # XXX This does not enforce that "to" is passed.
+        set_tag("to", str(parse_string(request, "to")))
 
         from_token = await StreamToken.from_string(self.store, from_token_string)
 
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index cf4196ac0a..0437c87d8d 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -28,7 +28,7 @@ from typing import (
 
 from typing_extensions import TypedDict
 
-from synapse.api.errors import Codes, LoginError, SynapseError
+from synapse.api.errors import Codes, InvalidClientTokenError, LoginError, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.api.urls import CLIENT_API_PREFIX
 from synapse.appservice import ApplicationService
@@ -172,7 +172,13 @@ class LoginRestServlet(RestServlet):
 
         try:
             if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
-                appservice = self.auth.get_appservice_by_req(request)
+                requester = await self.auth.get_user_by_req(request)
+                appservice = requester.app_service
+
+                if appservice is None:
+                    raise InvalidClientTokenError(
+                        "This login method is only valid for application services"
+                    )
 
                 if appservice.is_rate_limited():
                     await self._address_ratelimiter.ratelimit(
@@ -420,17 +426,31 @@ class LoginRestServlet(RestServlet):
                 403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN
             )
 
-        import jwt
+        from authlib.jose import JsonWebToken, JWTClaims
+        from authlib.jose.errors import BadSignatureError, InvalidClaimError, JoseError
+
+        jwt = JsonWebToken([self.jwt_algorithm])
+        claim_options = {}
+        if self.jwt_issuer is not None:
+            claim_options["iss"] = {"value": self.jwt_issuer, "essential": True}
+        if self.jwt_audiences is not None:
+            claim_options["aud"] = {"values": self.jwt_audiences, "essential": True}
 
         try:
-            payload = jwt.decode(
+            claims = jwt.decode(
                 token,
-                self.jwt_secret,
-                algorithms=[self.jwt_algorithm],
-                issuer=self.jwt_issuer,
-                audience=self.jwt_audiences,
+                key=self.jwt_secret,
+                claims_cls=JWTClaims,
+                claims_options=claim_options,
+            )
+        except BadSignatureError:
+            # We handle this case separately to provide a better error message
+            raise LoginError(
+                403,
+                "JWT validation failed: Signature verification failed",
+                errcode=Codes.FORBIDDEN,
             )
-        except jwt.PyJWTError as e:
+        except JoseError as e:
             # A JWT error occurred, return some info back to the client.
             raise LoginError(
                 403,
@@ -438,7 +458,23 @@ class LoginRestServlet(RestServlet):
                 errcode=Codes.FORBIDDEN,
             )
 
-        user = payload.get(self.jwt_subject_claim, None)
+        try:
+            claims.validate(leeway=120)  # allows 2 min of clock skew
+
+            # Enforce the old behavior which is rolled out in productive
+            # servers: if the JWT contains an 'aud' claim but none is
+            # configured, the login attempt will fail
+            if claims.get("aud") is not None:
+                if self.jwt_audiences is None or len(self.jwt_audiences) == 0:
+                    raise InvalidClaimError("aud")
+        except JoseError as e:
+            raise LoginError(
+                403,
+                "JWT validation failed: %s" % (str(e),),
+                errcode=Codes.FORBIDDEN,
+            )
+
+        user = claims.get(self.jwt_subject_claim, None)
         if user is None:
             raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)
 
diff --git a/synapse/rest/client/models.py b/synapse/rest/client/models.py
new file mode 100644
index 0000000000..3150602997
--- /dev/null
+++ b/synapse/rest/client/models.py
@@ -0,0 +1,69 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING, Dict, Optional
+
+from pydantic import Extra, StrictInt, StrictStr, constr, validator
+
+from synapse.rest.models import RequestBodyModel
+from synapse.util.threepids import validate_email
+
+
+class AuthenticationData(RequestBodyModel):
+    """
+    Data used during user-interactive authentication.
+
+    (The name "Authentication Data" is taken directly from the spec.)
+
+    Additional keys will be present, depending on the `type` field. Use `.dict()` to
+    access them.
+    """
+
+    class Config:
+        extra = Extra.allow
+
+    session: Optional[StrictStr] = None
+    type: Optional[StrictStr] = None
+
+
+class EmailRequestTokenBody(RequestBodyModel):
+    if TYPE_CHECKING:
+        client_secret: StrictStr
+    else:
+        # See also assert_valid_client_secret()
+        client_secret: constr(
+            regex="[0-9a-zA-Z.=_-]",  # noqa: F722
+            min_length=0,
+            max_length=255,
+            strict=True,
+        )
+    email: StrictStr
+    id_server: Optional[StrictStr]
+    id_access_token: Optional[StrictStr]
+    next_link: Optional[StrictStr]
+    send_attempt: StrictInt
+
+    @validator("id_access_token", always=True)
+    def token_required_for_identity_server(
+        cls, token: Optional[str], values: Dict[str, object]
+    ) -> Optional[str]:
+        if values.get("id_server") is not None and token is None:
+            raise ValueError("id_access_token is required if an id_server is supplied.")
+        return token
+
+    # Canonicalise the email address. The addresses are all stored canonicalised
+    # in the database. This allows the user to reset his password without having to
+    # know the exact spelling (eg. upper and lower case) of address in the database.
+    # Without this, an email stored in the database as "foo@bar.com" would cause
+    # user requests for "FOO@bar.com" to raise a Not Found error.
+    _email_validator = validator("email", allow_reuse=True)(validate_email)
diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py
index 24bc7c9095..61268e3af1 100644
--- a/synapse/rest/client/notifications.py
+++ b/synapse/rest/client/notifications.py
@@ -58,7 +58,11 @@ class NotificationsServlet(RestServlet):
         )
 
         receipts_by_room = await self.store.get_receipts_for_user_with_orderings(
-            user_id, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
+            user_id,
+            [
+                ReceiptTypes.READ,
+                ReceiptTypes.READ_PRIVATE,
+            ],
         )
 
         notif_event_ids = [pa.event_id for pa in push_actions]
diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py
index c684636c0a..e69fa0829d 100644
--- a/synapse/rest/client/profile.py
+++ b/synapse/rest/client/profile.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 """ This module contains REST servlets to do with profile: /profile/<paths> """
-
+from http import HTTPStatus
 from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.errors import Codes, SynapseError
@@ -45,8 +45,12 @@ class ProfileDisplaynameRestServlet(RestServlet):
             requester = await self.auth.get_user_by_req(request)
             requester_user = requester.user
 
-        user = UserID.from_string(user_id)
+        if not UserID.is_valid(user_id):
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
+            )
 
+        user = UserID.from_string(user_id)
         await self.profile_handler.check_profile_query_allowed(user, requester_user)
 
         displayname = await self.profile_handler.get_displayname(user)
@@ -62,7 +66,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         user = UserID.from_string(user_id)
-        is_admin = await self.auth.is_server_admin(requester.user)
+        is_admin = await self.auth.is_server_admin(requester)
 
         content = parse_json_object_from_request(request)
 
@@ -98,8 +102,12 @@ class ProfileAvatarURLRestServlet(RestServlet):
             requester = await self.auth.get_user_by_req(request)
             requester_user = requester.user
 
-        user = UserID.from_string(user_id)
+        if not UserID.is_valid(user_id):
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
+            )
 
+        user = UserID.from_string(user_id)
         await self.profile_handler.check_profile_query_allowed(user, requester_user)
 
         avatar_url = await self.profile_handler.get_avatar_url(user)
@@ -115,7 +123,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         user = UserID.from_string(user_id)
-        is_admin = await self.auth.is_server_admin(requester.user)
+        is_admin = await self.auth.is_server_admin(requester)
 
         content = parse_json_object_from_request(request)
         try:
@@ -150,8 +158,12 @@ class ProfileRestServlet(RestServlet):
             requester = await self.auth.get_user_by_req(request)
             requester_user = requester.user
 
-        user = UserID.from_string(user_id)
+        if not UserID.is_valid(user_id):
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
+            )
 
+        user = UserID.from_string(user_id)
         await self.profile_handler.check_profile_query_allowed(user, requester_user)
 
         displayname = await self.profile_handler.get_displayname(user)
diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py
index d6487c31dd..9a1f10f4be 100644
--- a/synapse/rest/client/pusher.py
+++ b/synapse/rest/client/pusher.py
@@ -1,4 +1,5 @@
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2022 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,17 +16,17 @@
 import logging
 from typing import TYPE_CHECKING, Tuple
 
-from synapse.api.errors import Codes, StoreError, SynapseError
-from synapse.http.server import HttpServer, respond_with_html_bytes
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import HttpServer
 from synapse.http.servlet import (
     RestServlet,
     assert_params_in_dict,
     parse_json_object_from_request,
-    parse_string,
 )
 from synapse.http.site import SynapseRequest
 from synapse.push import PusherConfigException
 from synapse.rest.client._base import client_patterns
+from synapse.rest.synapse.client.unsubscribe import UnsubscribeResource
 from synapse.types import JsonDict
 
 if TYPE_CHECKING:
@@ -132,48 +133,21 @@ class PushersSetRestServlet(RestServlet):
         return 200, {}
 
 
-class PushersRemoveRestServlet(RestServlet):
+class LegacyPushersRemoveRestServlet(UnsubscribeResource, RestServlet):
     """
-    To allow pusher to be delete by clicking a link (ie. GET request)
+    A servlet to handle legacy "email unsubscribe" links, forwarding requests to the ``UnsubscribeResource``
+
+    This should be kept for some time, so unsubscribe links in past emails stay valid.
     """
 
-    PATTERNS = client_patterns("/pushers/remove$", v1=True)
-    SUCCESS_HTML = b"<html><body>You have been unsubscribed</body><html>"
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.hs = hs
-        self.notifier = hs.get_notifier()
-        self.auth = hs.get_auth()
-        self.pusher_pool = self.hs.get_pusherpool()
+    PATTERNS = client_patterns("/pushers/remove$", releases=[], v1=False, unstable=True)
 
     async def on_GET(self, request: SynapseRequest) -> None:
-        requester = await self.auth.get_user_by_req(request, rights="delete_pusher")
-        user = requester.user
-
-        app_id = parse_string(request, "app_id", required=True)
-        pushkey = parse_string(request, "pushkey", required=True)
-
-        try:
-            await self.pusher_pool.remove_pusher(
-                app_id=app_id, pushkey=pushkey, user_id=user.to_string()
-            )
-        except StoreError as se:
-            if se.code != 404:
-                # This is fine: they're already unsubscribed
-                raise
-
-        self.notifier.on_new_replication_data()
-
-        respond_with_html_bytes(
-            request,
-            200,
-            PushersRemoveRestServlet.SUCCESS_HTML,
-        )
-        return None
+        # Forward the request to the UnsubscribeResource
+        await self._async_render(request)
 
 
 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     PushersRestServlet(hs).register(http_server)
     PushersSetRestServlet(hs).register(http_server)
-    PushersRemoveRestServlet(hs).register(http_server)
+    LegacyPushersRemoveRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py
index 3644705e6a..5e53096539 100644
--- a/synapse/rest/client/read_marker.py
+++ b/synapse/rest/client/read_marker.py
@@ -40,6 +40,12 @@ class ReadMarkerRestServlet(RestServlet):
         self.read_marker_handler = hs.get_read_marker_handler()
         self.presence_handler = hs.get_presence_handler()
 
+        self._known_receipt_types = {
+            ReceiptTypes.READ,
+            ReceiptTypes.FULLY_READ,
+            ReceiptTypes.READ_PRIVATE,
+        }
+
     async def on_POST(
         self, request: SynapseRequest, room_id: str
     ) -> Tuple[int, JsonDict]:
@@ -49,13 +55,7 @@ class ReadMarkerRestServlet(RestServlet):
 
         body = parse_json_object_from_request(request)
 
-        valid_receipt_types = {
-            ReceiptTypes.READ,
-            ReceiptTypes.FULLY_READ,
-            ReceiptTypes.READ_PRIVATE,
-        }
-
-        unrecognized_types = set(body.keys()) - valid_receipt_types
+        unrecognized_types = set(body.keys()) - self._known_receipt_types
         if unrecognized_types:
             # It's fine if there are unrecognized receipt types, but let's log
             # it to help debug clients that have typoed the receipt type.
@@ -65,31 +65,25 @@ class ReadMarkerRestServlet(RestServlet):
             # types.
             logger.info("Ignoring unrecognized receipt types: %s", unrecognized_types)
 
-        read_event_id = body.get(ReceiptTypes.READ, None)
-        if read_event_id:
-            await self.receipts_handler.received_client_receipt(
-                room_id,
-                ReceiptTypes.READ,
-                user_id=requester.user.to_string(),
-                event_id=read_event_id,
-            )
-
-        read_private_event_id = body.get(ReceiptTypes.READ_PRIVATE, None)
-        if read_private_event_id and self.config.experimental.msc2285_enabled:
-            await self.receipts_handler.received_client_receipt(
-                room_id,
-                ReceiptTypes.READ_PRIVATE,
-                user_id=requester.user.to_string(),
-                event_id=read_private_event_id,
-            )
-
-        read_marker_event_id = body.get(ReceiptTypes.FULLY_READ, None)
-        if read_marker_event_id:
-            await self.read_marker_handler.received_client_read_marker(
-                room_id,
-                user_id=requester.user.to_string(),
-                event_id=read_marker_event_id,
-            )
+        for receipt_type in self._known_receipt_types:
+            event_id = body.get(receipt_type, None)
+            # TODO Add validation to reject non-string event IDs.
+            if not event_id:
+                continue
+
+            if receipt_type == ReceiptTypes.FULLY_READ:
+                await self.read_marker_handler.received_client_read_marker(
+                    room_id,
+                    user_id=requester.user.to_string(),
+                    event_id=event_id,
+                )
+            else:
+                await self.receipts_handler.received_client_receipt(
+                    room_id,
+                    receipt_type,
+                    user_id=requester.user.to_string(),
+                    event_id=event_id,
+                )
 
         return 200, {}
 
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index 4b03eb876b..5b7fad7402 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -39,31 +39,27 @@ class ReceiptRestServlet(RestServlet):
 
     def __init__(self, hs: "HomeServer"):
         super().__init__()
-        self.hs = hs
         self.auth = hs.get_auth()
         self.receipts_handler = hs.get_receipts_handler()
         self.read_marker_handler = hs.get_read_marker_handler()
         self.presence_handler = hs.get_presence_handler()
 
+        self._known_receipt_types = {
+            ReceiptTypes.READ,
+            ReceiptTypes.READ_PRIVATE,
+            ReceiptTypes.FULLY_READ,
+        }
+
     async def on_POST(
         self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
 
-        if self.hs.config.experimental.msc2285_enabled and receipt_type not in [
-            ReceiptTypes.READ,
-            ReceiptTypes.READ_PRIVATE,
-            ReceiptTypes.FULLY_READ,
-        ]:
+        if receipt_type not in self._known_receipt_types:
             raise SynapseError(
                 400,
-                "Receipt type must be 'm.read', 'org.matrix.msc2285.read.private' or 'm.fully_read'",
+                f"Receipt type must be {', '.join(self._known_receipt_types)}",
             )
-        elif (
-            not self.hs.config.experimental.msc2285_enabled
-            and receipt_type != ReceiptTypes.READ
-        ):
-            raise SynapseError(400, "Receipt type must be 'm.read'")
 
         parse_json_object_from_request(request, allow_empty_body=False)
 
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index e8e51a9c66..20bab20c8f 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -31,9 +31,8 @@ from synapse.api.errors import (
 )
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.config import ConfigError
-from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.config.homeserver import HomeServerConfig
-from synapse.config.ratelimiting import FederationRateLimitConfig
+from synapse.config.ratelimiting import FederationRatelimitSettings
 from synapse.config.server import is_threepid_reserved
 from synapse.handlers.auth import AuthHandler
 from synapse.handlers.ui_auth import UIAuthSessionDataConstants
@@ -74,7 +73,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
         self.identity_handler = hs.get_identity_handler()
         self.config = hs.config
 
-        if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+        if self.hs.config.email.can_verify_email:
             self.mailer = Mailer(
                 hs=self.hs,
                 app_name=self.config.email.email_app_name,
@@ -83,13 +82,10 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
             )
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
-            if (
-                self.hs.config.email.local_threepid_handling_disabled_due_to_email_config
-            ):
-                logger.warning(
-                    "Email registration has been disabled due to lack of email config"
-                )
+        if not self.hs.config.email.can_verify_email:
+            logger.warning(
+                "Email registration has been disabled due to lack of email config"
+            )
             raise SynapseError(
                 400, "Email-based registration has been disabled on this server"
             )
@@ -138,35 +134,21 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
 
             raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
 
-        if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
-            assert self.hs.config.registration.account_threepid_delegate_email
-
-            # Have the configured identity server handle the request
-            ret = await self.identity_handler.requestEmailToken(
-                self.hs.config.registration.account_threepid_delegate_email,
-                email,
-                client_secret,
-                send_attempt,
-                next_link,
-            )
-        else:
-            # Send registration emails from Synapse
-            sid = await self.identity_handler.send_threepid_validation(
-                email,
-                client_secret,
-                send_attempt,
-                self.mailer.send_registration_mail,
-                next_link,
-            )
-
-            # Wrap the session id in a JSON object
-            ret = {"sid": sid}
+        # Send registration emails from Synapse
+        sid = await self.identity_handler.send_threepid_validation(
+            email,
+            client_secret,
+            send_attempt,
+            self.mailer.send_registration_mail,
+            next_link,
+        )
 
         threepid_send_requests.labels(type="email", reason="register").observe(
             send_attempt
         )
 
-        return 200, ret
+        # Wrap the session id in a JSON object
+        return 200, {"sid": sid}
 
 
 class MsisdnRegisterRequestTokenRestServlet(RestServlet):
@@ -260,7 +242,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
         self.clock = hs.get_clock()
         self.store = hs.get_datastores().main
 
-        if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+        if self.config.email.can_verify_email:
             self._failure_email_template = (
                 self.config.email.email_registration_template_failure_html
             )
@@ -270,11 +252,10 @@ class RegistrationSubmitTokenServlet(RestServlet):
             raise SynapseError(
                 400, "This medium is currently not supported for registration"
             )
-        if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
-            if self.config.email.local_threepid_handling_disabled_due_to_email_config:
-                logger.warning(
-                    "User registration via email has been disabled due to lack of email config"
-                )
+        if not self.config.email.can_verify_email:
+            logger.warning(
+                "User registration via email has been disabled due to lack of email config"
+            )
             raise SynapseError(
                 400, "Email-based registration is disabled on this server"
             )
@@ -325,7 +306,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
         self.registration_handler = hs.get_registration_handler()
         self.ratelimiter = FederationRateLimiter(
             hs.get_clock(),
-            FederationRateLimitConfig(
+            FederationRatelimitSettings(
                 # Time window of 2s
                 window_size=2000,
                 # Artificially delay requests if rate > sleep_limit/window_size
@@ -484,9 +465,6 @@ class RegisterRestServlet(RestServlet):
                     "Appservice token must be provided when using a type of m.login.application_service",
                 )
 
-            # Verify the AS
-            self.auth.get_appservice_by_req(request)
-
             # Set the desired user according to the AS API (which uses the
             # 'user' key not 'username'). Since this is a new addition, we'll
             # fallback to 'username' if they gave one.
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index 3cae6d2b55..ce97080013 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -43,6 +43,7 @@ class RelationPaginationServlet(RestServlet):
         self.auth = hs.get_auth()
         self.store = hs.get_datastores().main
         self._relations_handler = hs.get_relations_handler()
+        self._msc3715_enabled = hs.config.experimental.msc3715_enabled
 
     async def on_GET(
         self,
@@ -55,9 +56,15 @@ class RelationPaginationServlet(RestServlet):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
         limit = parse_integer(request, "limit", default=5)
-        direction = parse_string(
-            request, "org.matrix.msc3715.dir", default="b", allowed_values=["f", "b"]
-        )
+        if self._msc3715_enabled:
+            direction = parse_string(
+                request,
+                "org.matrix.msc3715.dir",
+                default="b",
+                allowed_values=["f", "b"],
+            )
+        else:
+            direction = "b"
         from_token_str = parse_string(request, "from")
         to_token_str = parse_string(request, "to")
 
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index a26e976492..0bca012535 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -16,9 +16,13 @@
 """ This module contains REST servlets to do with rooms: /rooms/<paths> """
 import logging
 import re
+from enum import Enum
+from http import HTTPStatus
 from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple
 from urllib import parse as urlparse
 
+from prometheus_client.core import Histogram
+
 from twisted.web.server import Request
 
 from synapse import event_auth
@@ -34,7 +38,7 @@ from synapse.api.errors import (
 )
 from synapse.api.filtering import Filter
 from synapse.events.utils import format_event_for_client_v2
-from synapse.http.server import HttpServer, cancellable
+from synapse.http.server import HttpServer
 from synapse.http.servlet import (
     ResolveRoomIdMixin,
     RestServlet,
@@ -46,6 +50,7 @@ from synapse.http.servlet import (
     parse_strings_from_args,
 )
 from synapse.http.site import SynapseRequest
+from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.opentracing import set_tag
 from synapse.rest.client._base import client_patterns
 from synapse.rest.client.transactions import HttpTransactionCache
@@ -53,6 +58,7 @@ from synapse.storage.state import StateFilter
 from synapse.streams.config import PaginationConfig
 from synapse.types import JsonDict, StreamToken, ThirdPartyInstanceID, UserID
 from synapse.util import json_decoder
+from synapse.util.cancellation import cancellable
 from synapse.util.stringutils import parse_and_validate_server_name, random_string
 
 if TYPE_CHECKING:
@@ -61,6 +67,70 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+class _RoomSize(Enum):
+    """
+    Enum to differentiate sizes of rooms. This is a pretty good approximation
+    about how hard it will be to get events in the room. We could also look at
+    room "complexity".
+    """
+
+    # This doesn't necessarily mean the room is a DM, just that there is a DM
+    # amount of people there.
+    DM_SIZE = "direct_message_size"
+    SMALL = "small"
+    SUBSTANTIAL = "substantial"
+    LARGE = "large"
+
+    @staticmethod
+    def from_member_count(member_count: int) -> "_RoomSize":
+        if member_count <= 2:
+            return _RoomSize.DM_SIZE
+        elif member_count < 100:
+            return _RoomSize.SMALL
+        elif member_count < 1000:
+            return _RoomSize.SUBSTANTIAL
+        else:
+            return _RoomSize.LARGE
+
+
+# This is an extra metric on top of `synapse_http_server_response_time_seconds`
+# which times the same sort of thing but this one allows us to see values
+# greater than 10s. We use a separate dedicated histogram with its own buckets
+# so that we don't increase the cardinality of the general one because it's
+# multiplied across hundreds of servlets.
+messsages_response_timer = Histogram(
+    "synapse_room_message_list_rest_servlet_response_time_seconds",
+    "sec",
+    # We have a label for room size so we can try to see a more realistic
+    # picture of /messages response time for bigger rooms. We don't want the
+    # tiny rooms that can always respond fast skewing our results when we're trying
+    # to optimize the bigger cases.
+    ["room_size"],
+    buckets=(
+        0.005,
+        0.01,
+        0.025,
+        0.05,
+        0.1,
+        0.25,
+        0.5,
+        1.0,
+        2.5,
+        5.0,
+        10.0,
+        20.0,
+        30.0,
+        60.0,
+        80.0,
+        100.0,
+        120.0,
+        150.0,
+        180.0,
+        "+Inf",
+    ),
+)
+
+
 class TransactionRestServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         super().__init__()
@@ -165,7 +235,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
 
         msg_handler = self.message_handler
         data = await msg_handler.get_room_data(
-            user_id=requester.user.to_string(),
+            requester=requester,
             room_id=room_id,
             event_type=event_type,
             state_key=state_key,
@@ -510,7 +580,7 @@ class RoomMemberListRestServlet(RestServlet):
 
         events = await handler.get_state_events(
             room_id=room_id,
-            user_id=requester.user.to_string(),
+            requester=requester,
             at_token=at_token,
             state_filter=StateFilter.from_types([(EventTypes.Member, None)]),
         )
@@ -556,6 +626,7 @@ class RoomMessageListRestServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         super().__init__()
         self._hs = hs
+        self.clock = hs.get_clock()
         self.pagination_handler = hs.get_pagination_handler()
         self.auth = hs.get_auth()
         self.store = hs.get_datastores().main
@@ -563,6 +634,18 @@ class RoomMessageListRestServlet(RestServlet):
     async def on_GET(
         self, request: SynapseRequest, room_id: str
     ) -> Tuple[int, JsonDict]:
+        processing_start_time = self.clock.time_msec()
+        # Fire off and hope that we get a result by the end.
+        #
+        # We're using the mypy type ignore comment because the `@cached`
+        # decorator on `get_number_joined_users_in_room` doesn't play well with
+        # the type system. Maybe in the future, it can use some ParamSpec
+        # wizardry to fix it up.
+        room_member_count_deferred = run_in_background(  # type: ignore[call-arg]
+            self.store.get_number_joined_users_in_room,
+            room_id,  # type: ignore[arg-type]
+        )
+
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         pagination_config = await PaginationConfig.from_request(
             self.store, request, default_limit=10
@@ -593,6 +676,12 @@ class RoomMessageListRestServlet(RestServlet):
             event_filter=event_filter,
         )
 
+        processing_end_time = self.clock.time_msec()
+        room_member_count = await make_deferred_yieldable(room_member_count_deferred)
+        messsages_response_timer.labels(
+            room_size=_RoomSize.from_member_count(room_member_count)
+        ).observe((processing_end_time - processing_start_time) / 1000)
+
         return 200, msgs
 
 
@@ -613,8 +702,7 @@ class RoomStateRestServlet(RestServlet):
         # Get all the current state for this room
         events = await self.message_handler.get_state_events(
             room_id=room_id,
-            user_id=requester.user.to_string(),
-            is_guest=requester.is_guest,
+            requester=requester,
         )
         return 200, events
 
@@ -672,7 +760,7 @@ class RoomEventServlet(RestServlet):
             == "true"
         )
         if include_unredacted_content and not await self.auth.is_server_admin(
-            requester.user
+            requester
         ):
             power_level_event = (
                 await self._storage_controllers.state.get_current_state_event(
@@ -860,7 +948,16 @@ class RoomMembershipRestServlet(TransactionRestServlet):
             # cheekily send invalid bodies.
             content = {}
 
-        if membership_action == "invite" and self._has_3pid_invite_keys(content):
+        if membership_action == "invite" and all(
+            key in content for key in ("medium", "address")
+        ):
+            if not all(key in content for key in ("id_server", "id_access_token")):
+                raise SynapseError(
+                    HTTPStatus.BAD_REQUEST,
+                    "`id_server` and `id_access_token` are required when doing 3pid invite",
+                    Codes.MISSING_PARAM,
+                )
+
             try:
                 await self.room_member_handler.do_3pid_invite(
                     room_id,
@@ -870,7 +967,7 @@ class RoomMembershipRestServlet(TransactionRestServlet):
                     content["id_server"],
                     requester,
                     txn_id,
-                    content.get("id_access_token"),
+                    content["id_access_token"],
                 )
             except ShadowBanError:
                 # Pretend the request succeeded.
@@ -907,12 +1004,6 @@ class RoomMembershipRestServlet(TransactionRestServlet):
 
         return 200, return_value
 
-    def _has_3pid_invite_keys(self, content: JsonDict) -> bool:
-        for key in {"id_server", "medium", "address"}:
-            if key not in content:
-                return False
-        return True
-
     def on_PUT(
         self, request: SynapseRequest, room_id: str, membership_action: str, txn_id: str
     ) -> Awaitable[Tuple[int, JsonDict]]:
@@ -1177,7 +1268,7 @@ class TimestampLookupRestServlet(RestServlet):
         self, request: SynapseRequest, room_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self._auth.get_user_by_req(request)
-        await self._auth.check_user_in_room(room_id, requester.user.to_string())
+        await self._auth.check_user_in_room_or_world_readable(room_id, requester)
 
         timestamp = parse_integer(request, "ts", required=True)
         direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])
diff --git a/synapse/rest/client/room_keys.py b/synapse/rest/client/room_keys.py
index 37e39570f6..f7081f638e 100644
--- a/synapse/rest/client/room_keys.py
+++ b/synapse/rest/client/room_keys.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple, cast
 
 from synapse.api.errors import Codes, NotFoundError, SynapseError
 from synapse.http.server import HttpServer
@@ -127,7 +127,7 @@ class RoomKeysServlet(RestServlet):
         requester = await self.auth.get_user_by_req(request, allow_guest=False)
         user_id = requester.user.to_string()
         body = parse_json_object_from_request(request)
-        version = parse_string(request, "version")
+        version = parse_string(request, "version", required=True)
 
         if session_id:
             body = {"sessions": {session_id: body}}
@@ -196,8 +196,11 @@ class RoomKeysServlet(RestServlet):
         user_id = requester.user.to_string()
         version = parse_string(request, "version", required=True)
 
-        room_keys = await self.e2e_room_keys_handler.get_room_keys(
-            user_id, version, room_id, session_id
+        room_keys = cast(
+            JsonDict,
+            await self.e2e_room_keys_handler.get_room_keys(
+                user_id, version, room_id, session_id
+            ),
         )
 
         # Convert room_keys to the right format to return.
@@ -240,7 +243,7 @@ class RoomKeysServlet(RestServlet):
 
         requester = await self.auth.get_user_by_req(request, allow_guest=False)
         user_id = requester.user.to_string()
-        version = parse_string(request, "version")
+        version = parse_string(request, "version", required=True)
 
         ret = await self.e2e_room_keys_handler.delete_room_keys(
             user_id, version, room_id, session_id
diff --git a/synapse/rest/client/sendtodevice.py b/synapse/rest/client/sendtodevice.py
index 3322c8ef48..46a8b03829 100644
--- a/synapse/rest/client/sendtodevice.py
+++ b/synapse/rest/client/sendtodevice.py
@@ -19,7 +19,7 @@ from synapse.http import servlet
 from synapse.http.server import HttpServer
 from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
 from synapse.http.site import SynapseRequest
-from synapse.logging.opentracing import set_tag, trace
+from synapse.logging.opentracing import set_tag
 from synapse.rest.client.transactions import HttpTransactionCache
 from synapse.types import JsonDict
 
@@ -43,7 +43,6 @@ class SendToDeviceRestServlet(servlet.RestServlet):
         self.txns = HttpTransactionCache(hs)
         self.device_message_handler = hs.get_device_message_handler()
 
-    @trace(opname="sendToDevice")
     def on_PUT(
         self, request: SynapseRequest, message_type: str, txn_id: str
     ) -> Awaitable[Tuple[int, JsonDict]]:
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 8bbf35148d..c2989765ce 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -37,7 +37,7 @@ from synapse.handlers.sync import (
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
-from synapse.logging.opentracing import trace
+from synapse.logging.opentracing import trace_with_opname
 from synapse.types import JsonDict, StreamToken
 from synapse.util import json_decoder
 
@@ -210,7 +210,7 @@ class SyncRestServlet(RestServlet):
         logger.debug("Event formatting complete")
         return 200, response_content
 
-    @trace(opname="sync.encode_response")
+    @trace_with_opname("sync.encode_response")
     async def encode_response(
         self,
         time_now: int,
@@ -315,7 +315,7 @@ class SyncRestServlet(RestServlet):
             ]
         }
 
-    @trace(opname="sync.encode_joined")
+    @trace_with_opname("sync.encode_joined")
     async def encode_joined(
         self,
         rooms: List[JoinedSyncResult],
@@ -340,7 +340,7 @@ class SyncRestServlet(RestServlet):
 
         return joined
 
-    @trace(opname="sync.encode_invited")
+    @trace_with_opname("sync.encode_invited")
     async def encode_invited(
         self,
         rooms: List[InvitedSyncResult],
@@ -371,7 +371,7 @@ class SyncRestServlet(RestServlet):
 
         return invited
 
-    @trace(opname="sync.encode_knocked")
+    @trace_with_opname("sync.encode_knocked")
     async def encode_knocked(
         self,
         rooms: List[KnockedSyncResult],
@@ -420,7 +420,7 @@ class SyncRestServlet(RestServlet):
 
         return knocked
 
-    @trace(opname="sync.encode_archived")
+    @trace_with_opname("sync.encode_archived")
     async def encode_archived(
         self,
         rooms: List[ArchivedSyncResult],
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index c1bd775fec..c516cda95d 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -94,7 +94,9 @@ class VersionsRestServlet(RestServlet):
                     # Supports the busy presence state described in MSC3026.
                     "org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled,
                     # Supports receiving private read receipts as per MSC2285
-                    "org.matrix.msc2285": self.config.experimental.msc2285_enabled,
+                    "org.matrix.msc2285.stable": True,  # TODO: Remove when MSC2285 becomes a part of the spec
+                    # Supports filtering of /publicRooms by room type as per MSC3827
+                    "org.matrix.msc3827.stable": True,
                     # Adds support for importing historical messages as per MSC2716
                     "org.matrix.msc2716": self.config.experimental.msc2716_enabled,
                     # Adds support for jump to date endpoints (/timestamp_to_event) as per MSC3030
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index f597157581..7f8ad29566 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -135,13 +135,6 @@ class RemoteKey(DirectServeJsonResource):
 
         store_queries = []
         for server_name, key_ids in query.items():
-            if (
-                self.federation_domain_whitelist is not None
-                and server_name not in self.federation_domain_whitelist
-            ):
-                logger.debug("Federation denied with %s", server_name)
-                continue
-
             if not key_ids:
                 key_ids = (None,)
             for key_id in key_ids:
@@ -153,21 +146,28 @@ class RemoteKey(DirectServeJsonResource):
 
         time_now_ms = self.clock.time_msec()
 
-        # Note that the value is unused.
+        # Map server_name->key_id->int. Note that the value of the init is unused.
+        # XXX: why don't we just use a set?
         cache_misses: Dict[str, Dict[str, int]] = {}
         for (server_name, key_id, _), key_results in cached.items():
             results = [(result["ts_added_ms"], result) for result in key_results]
 
-            if not results and key_id is not None:
-                cache_misses.setdefault(server_name, {})[key_id] = 0
+            if key_id is None:
+                # all keys were requested. Just return what we have without worrying
+                # about validity
+                for _, result in results:
+                    # Cast to bytes since postgresql returns a memoryview.
+                    json_results.add(bytes(result["key_json"]))
                 continue
 
-            if key_id is not None:
+            miss = False
+            if not results:
+                miss = True
+            else:
                 ts_added_ms, most_recent_result = max(results)
                 ts_valid_until_ms = most_recent_result["ts_valid_until_ms"]
                 req_key = query.get(server_name, {}).get(key_id, {})
                 req_valid_until = req_key.get("minimum_valid_until_ts")
-                miss = False
                 if req_valid_until is not None:
                     if ts_valid_until_ms < req_valid_until:
                         logger.debug(
@@ -211,19 +211,20 @@ class RemoteKey(DirectServeJsonResource):
                         ts_valid_until_ms,
                         time_now_ms,
                     )
-
-                if miss:
-                    cache_misses.setdefault(server_name, {})[key_id] = 0
                 # Cast to bytes since postgresql returns a memoryview.
                 json_results.add(bytes(most_recent_result["key_json"]))
-            else:
-                for _, result in results:
-                    # Cast to bytes since postgresql returns a memoryview.
-                    json_results.add(bytes(result["key_json"]))
+
+            if miss and query_remote_on_cache_miss:
+                # only bother attempting to fetch keys from servers on our whitelist
+                if (
+                    self.federation_domain_whitelist is None
+                    or server_name in self.federation_domain_whitelist
+                ):
+                    cache_misses.setdefault(server_name, {})[key_id] = 0
 
         # If there is a cache miss, request the missing keys, then recurse (and
         # ensure the result is sent).
-        if cache_misses and query_remote_on_cache_miss:
+        if cache_misses:
             await yieldable_gather_results(
                 lambda t: self.fetcher.get_keys(*t),
                 (
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index c35d42fab8..d30878f704 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -254,30 +254,32 @@ async def respond_with_responder(
         file_size: Size in bytes of the media. If not known it should be None
         upload_name: The name of the requested file, if any.
     """
-    if request._disconnected:
-        logger.warning(
-            "Not sending response to request %s, already disconnected.", request
-        )
-        return
-
     if not responder:
         respond_404(request)
         return
 
-    logger.debug("Responding to media request with responder %s", responder)
-    add_file_headers(request, media_type, file_size, upload_name)
-    try:
-        with responder:
+    # If we have a responder we *must* use it as a context manager.
+    with responder:
+        if request._disconnected:
+            logger.warning(
+                "Not sending response to request %s, already disconnected.", request
+            )
+            return
+
+        logger.debug("Responding to media request with responder %s", responder)
+        add_file_headers(request, media_type, file_size, upload_name)
+        try:
+
             await responder.write_to_consumer(request)
-    except Exception as e:
-        # The majority of the time this will be due to the client having gone
-        # away. Unfortunately, Twisted simply throws a generic exception at us
-        # in that case.
-        logger.warning("Failed to write to consumer: %s %s", type(e), e)
-
-        # Unregister the producer, if it has one, so Twisted doesn't complain
-        if request.producer:
-            request.unregisterProducer()
+        except Exception as e:
+            # The majority of the time this will be due to the client having gone
+            # away. Unfortunately, Twisted simply throws a generic exception at us
+            # in that case.
+            logger.warning("Failed to write to consumer: %s %s", type(e), e)
+
+            # Unregister the producer, if it has one, so Twisted doesn't complain
+            if request.producer:
+                request.unregisterProducer()
 
     finish_request(request)
 
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index 6180fa575e..048a042692 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -15,7 +15,11 @@
 import logging
 from typing import TYPE_CHECKING
 
-from synapse.http.server import DirectServeJsonResource, set_cors_headers
+from synapse.http.server import (
+    DirectServeJsonResource,
+    set_corp_headers,
+    set_cors_headers,
+)
 from synapse.http.servlet import parse_boolean
 from synapse.http.site import SynapseRequest
 
@@ -38,6 +42,7 @@ class DownloadResource(DirectServeJsonResource):
 
     async def _async_render_GET(self, request: SynapseRequest) -> None:
         set_cors_headers(request)
+        set_corp_headers(request)
         request.setHeader(
             b"Content-Security-Policy",
             b"sandbox;"
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 7435fd9130..9dd3c8d4bb 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -64,7 +64,6 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
-
 # How often to run the background job to update the "recently accessed"
 # attribute of local and remote media.
 UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000  # 1 minute
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 604f18bf52..a5c3de192f 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -36,6 +36,7 @@ from twisted.internet.defer import Deferred
 from twisted.internet.interfaces import IConsumer
 from twisted.protocols.basic import FileSender
 
+import synapse
 from synapse.api.errors import NotFoundError
 from synapse.logging.context import defer_to_thread, make_deferred_yieldable
 from synapse.util import Clock
@@ -145,15 +146,17 @@ class MediaStorage:
                     f.flush()
                     f.close()
 
-                    spam = await self.spam_checker.check_media_file_for_spam(
+                    spam_check = await self.spam_checker.check_media_file_for_spam(
                         ReadableFileWrapper(self.clock, fname), file_info
                     )
-                    if spam:
+                    if spam_check != synapse.module_api.NOT_SPAM:
                         logger.info("Blocking media due to spam checker")
                         # Note that we'll delete the stored media, due to the
                         # try/except below. The media also won't be stored in
                         # the DB.
-                        raise SpamMediaException()
+                        # We currently ignore any additional field returned by
+                        # the spam-check API.
+                        raise SpamMediaException(errcode=spam_check[0])
 
                     for provider in self.storage_providers:
                         await provider.store_file(path, file_info)
diff --git a/synapse/rest/media/v1/preview_html.py b/synapse/rest/media/v1/preview_html.py
index ed8f21a483..516d0434f0 100644
--- a/synapse/rest/media/v1/preview_html.py
+++ b/synapse/rest/media/v1/preview_html.py
@@ -12,10 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import codecs
-import itertools
 import logging
 import re
-from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Union
+from typing import (
+    TYPE_CHECKING,
+    Callable,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Union,
+)
 
 if TYPE_CHECKING:
     from lxml import etree
@@ -146,6 +155,70 @@ def decode_body(
     return etree.fromstring(body, parser)
 
 
+def _get_meta_tags(
+    tree: "etree.Element",
+    property: str,
+    prefix: str,
+    property_mapper: Optional[Callable[[str], Optional[str]]] = None,
+) -> Dict[str, Optional[str]]:
+    """
+    Search for meta tags prefixed with a particular string.
+
+    Args:
+        tree: The parsed HTML document.
+        property: The name of the property which contains the tag name, e.g.
+            "property" for Open Graph.
+        prefix: The prefix on the property to search for, e.g. "og" for Open Graph.
+        property_mapper: An optional callable to map the property to the Open Graph
+            form. Can return None for a key to ignore that key.
+
+    Returns:
+        A map of tag name to value.
+    """
+    results: Dict[str, Optional[str]] = {}
+    for tag in tree.xpath(
+        f"//*/meta[starts-with(@{property}, '{prefix}:')][@content][not(@content='')]"
+    ):
+        # if we've got more than 50 tags, someone is taking the piss
+        if len(results) >= 50:
+            logger.warning(
+                "Skipping parsing of Open Graph for page with too many '%s:' tags",
+                prefix,
+            )
+            return {}
+
+        key = tag.attrib[property]
+        if property_mapper:
+            key = property_mapper(key)
+            # None is a special value used to ignore a value.
+            if key is None:
+                continue
+
+        results[key] = tag.attrib["content"]
+
+    return results
+
+
+def _map_twitter_to_open_graph(key: str) -> Optional[str]:
+    """
+    Map a Twitter card property to the analogous Open Graph property.
+
+    Args:
+        key: The Twitter card property (starts with "twitter:").
+
+    Returns:
+        The Open Graph property (starts with "og:") or None to have this property
+        be ignored.
+    """
+    # Twitter card properties with no analogous Open Graph property.
+    if key == "twitter:card" or key == "twitter:creator":
+        return None
+    if key == "twitter:site":
+        return "og:site_name"
+    # Otherwise, swap twitter to og.
+    return "og" + key[7:]
+
+
 def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]:
     """
     Parse the HTML document into an Open Graph response.
@@ -160,10 +233,8 @@ def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]:
         The Open Graph response as a dictionary.
     """
 
-    # if we see any image URLs in the OG response, then spider them
-    # (although the client could choose to do this by asking for previews of those
-    # URLs to avoid DoSing the server)
-
+    # Search for Open Graph (og:) meta tags, e.g.:
+    #
     # "og:type"         : "video",
     # "og:url"          : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
     # "og:site_name"    : "YouTube",
@@ -176,19 +247,11 @@ def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]:
     # "og:video:height" : "720",
     # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
 
-    og: Dict[str, Optional[str]] = {}
-    for tag in tree.xpath(
-        "//*/meta[starts-with(@property, 'og:')][@content][not(@content='')]"
-    ):
-        # if we've got more than 50 tags, someone is taking the piss
-        if len(og) >= 50:
-            logger.warning("Skipping OG for page with too many 'og:' tags")
-            return {}
-
-        og[tag.attrib["property"]] = tag.attrib["content"]
-
-    # TODO: grab article: meta tags too, e.g.:
+    og = _get_meta_tags(tree, "property", "og")
 
+    # TODO: Search for properties specific to the different Open Graph types,
+    # such as article: meta tags, e.g.:
+    #
     # "article:publisher" : "https://www.facebook.com/thethudonline" />
     # "article:author" content="https://www.facebook.com/thethudonline" />
     # "article:tag" content="baby" />
@@ -196,6 +259,21 @@ def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]:
     # "article:published_time" content="2016-03-31T19:58:24+00:00" />
     # "article:modified_time" content="2016-04-01T18:31:53+00:00" />
 
+    # Search for Twitter Card (twitter:) meta tags, e.g.:
+    #
+    # "twitter:site"    : "@matrixdotorg"
+    # "twitter:creator" : "@matrixdotorg"
+    #
+    # Twitter cards tags also duplicate Open Graph tags.
+    #
+    # See https://developer.twitter.com/en/docs/twitter-for-websites/cards/guides/getting-started
+    twitter = _get_meta_tags(tree, "name", "twitter", _map_twitter_to_open_graph)
+    # Merge the Twitter values with the Open Graph values, but do not overwrite
+    # information from Open Graph tags.
+    for key, value in twitter.items():
+        if key not in og:
+            og[key] = value
+
     if "og:title" not in og:
         # Attempt to find a title from the title tag, or the biggest header on the page.
         title = tree.xpath("((//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1])/text()")
@@ -276,7 +354,7 @@ def parse_html_description(tree: "etree.Element") -> Optional[str]:
 
     from lxml import etree
 
-    TAGS_TO_REMOVE = (
+    TAGS_TO_REMOVE = {
         "header",
         "nav",
         "aside",
@@ -291,31 +369,42 @@ def parse_html_description(tree: "etree.Element") -> Optional[str]:
         "img",
         "picture",
         etree.Comment,
-    )
+    }
 
     # Split all the text nodes into paragraphs (by splitting on new
     # lines)
     text_nodes = (
         re.sub(r"\s+", "\n", el).strip()
-        for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
+        for el in _iterate_over_text(tree.find("body"), TAGS_TO_REMOVE)
     )
     return summarize_paragraphs(text_nodes)
 
 
 def _iterate_over_text(
-    tree: "etree.Element", *tags_to_ignore: Union[str, "etree.Comment"]
+    tree: Optional["etree.Element"],
+    tags_to_ignore: Set[Union[str, "etree.Comment"]],
+    stack_limit: int = 1024,
 ) -> Generator[str, None, None]:
     """Iterate over the tree returning text nodes in a depth first fashion,
     skipping text nodes inside certain tags.
+
+    Args:
+        tree: The parent element to iterate. Can be None if there isn't one.
+        tags_to_ignore: Set of tags to ignore
+        stack_limit: Maximum stack size limit for depth-first traversal.
+            Nodes will be dropped if this limit is hit, which may truncate the
+            textual result.
+            Intended to limit the maximum working memory when generating a preview.
     """
-    # This is basically a stack that we extend using itertools.chain.
-    # This will either consist of an element to iterate over *or* a string
+
+    if tree is None:
+        return
+
+    # This is a stack whose items are elements to iterate over *or* strings
     # to be returned.
-    elements = iter([tree])
-    while True:
-        el = next(elements, None)
-        if el is None:
-            return
+    elements: List[Union[str, "etree.Element"]] = [tree]
+    while elements:
+        el = elements.pop()
 
         if isinstance(el, str):
             yield el
@@ -329,17 +418,22 @@ def _iterate_over_text(
             if el.text:
                 yield el.text
 
-            # We add to the stack all the elements children, interspersed with
-            # each child's tail text (if it exists). The tail text of a node
-            # is text that comes *after* the node, so we always include it even
-            # if we ignore the child node.
-            elements = itertools.chain(
-                itertools.chain.from_iterable(  # Basically a flatmap
-                    [child, child.tail] if child.tail else [child]
-                    for child in el.iterchildren()
-                ),
-                elements,
-            )
+            # We add to the stack all the element's children, interspersed with
+            # each child's tail text (if it exists).
+            #
+            # We iterate in reverse order so that earlier pieces of text appear
+            # closer to the top of the stack.
+            for child in el.iterchildren(reversed=True):
+                if len(elements) > stack_limit:
+                    # We've hit our limit for working memory
+                    break
+
+                if child.tail:
+                    # The tail text of a node is text that comes *after* the node,
+                    # so we always include it even if we ignore the child node.
+                    elements.append(child.tail)
+
+                elements.append(child)
 
 
 def summarize_paragraphs(
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 54a849eac9..a8f6fd6b35 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -109,10 +109,64 @@ class MediaInfo:
 
 class PreviewUrlResource(DirectServeJsonResource):
     """
-    Generating URL previews is a complicated task which many potential pitfalls.
-
-    See docs/development/url_previews.md for discussion of the design and
-    algorithm followed in this module.
+    The `GET /_matrix/media/r0/preview_url` endpoint provides a generic preview API
+    for URLs which outputs Open Graph (https://ogp.me/) responses (with some Matrix
+    specific additions).
+
+    This does have trade-offs compared to other designs:
+
+    * Pros:
+      * Simple and flexible; can be used by any clients at any point
+    * Cons:
+      * If each homeserver provides one of these independently, all the homeservers in a
+        room may needlessly DoS the target URI
+      * The URL metadata must be stored somewhere, rather than just using Matrix
+        itself to store the media.
+      * Matrix cannot be used to distribute the metadata between homeservers.
+
+    When Synapse is asked to preview a URL it does the following:
+
+    1. Checks against a URL blacklist (defined as `url_preview_url_blacklist` in the
+       config).
+    2. Checks the URL against an in-memory cache and returns the result if it exists. (This
+       is also used to de-duplicate processing of multiple in-flight requests at once.)
+    3. Kicks off a background process to generate a preview:
+       1. Checks URL and timestamp against the database cache and returns the result if it
+          has not expired and was successful (a 2xx return code).
+       2. Checks if the URL matches an oEmbed (https://oembed.com/) pattern. If it
+          does, update the URL to download.
+       3. Downloads the URL and stores it into a file via the media storage provider
+          and saves the local media metadata.
+       4. If the media is an image:
+          1. Generates thumbnails.
+          2. Generates an Open Graph response based on image properties.
+       5. If the media is HTML:
+          1. Decodes the HTML via the stored file.
+          2. Generates an Open Graph response from the HTML.
+          3. If a JSON oEmbed URL was found in the HTML via autodiscovery:
+             1. Downloads the URL and stores it into a file via the media storage provider
+                and saves the local media metadata.
+             2. Convert the oEmbed response to an Open Graph response.
+             3. Override any Open Graph data from the HTML with data from oEmbed.
+          4. If an image exists in the Open Graph response:
+             1. Downloads the URL and stores it into a file via the media storage
+                provider and saves the local media metadata.
+             2. Generates thumbnails.
+             3. Updates the Open Graph response based on image properties.
+       6. If the media is JSON and an oEmbed URL was found:
+          1. Convert the oEmbed response to an Open Graph response.
+          2. If a thumbnail or image is in the oEmbed response:
+             1. Downloads the URL and stores it into a file via the media storage
+                provider and saves the local media metadata.
+             2. Generates thumbnails.
+             3. Updates the Open Graph response based on image properties.
+       7. Stores the result in the database cache.
+    4. Returns the result.
+
+    The in-memory cache expires after 1 hour.
+
+    Expired entries in the database cache (and their associated media files) are
+    deleted every 10 seconds. The default expiration time is 1 hour from download.
     """
 
     isLeaf = True
@@ -678,10 +732,6 @@ class PreviewUrlResource(DirectServeJsonResource):
 
         logger.debug("Running url preview cache expiry")
 
-        if not (await self.store.db_pool.updates.has_completed_background_updates()):
-            logger.debug("Still running DB updates; skipping url preview cache expiry")
-            return
-
         def try_remove_parent_dirs(dirs: Iterable[str]) -> None:
             """Attempt to remove the given chain of parent directories
 
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 53b1565243..5f725c7600 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -17,8 +17,14 @@
 import logging
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
-from synapse.api.errors import SynapseError
-from synapse.http.server import DirectServeJsonResource, set_cors_headers
+from synapse.api.errors import Codes, SynapseError, cs_error
+from synapse.config.repository import THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP
+from synapse.http.server import (
+    DirectServeJsonResource,
+    respond_with_json,
+    set_corp_headers,
+    set_cors_headers,
+)
 from synapse.http.servlet import parse_integer, parse_string
 from synapse.http.site import SynapseRequest
 from synapse.rest.media.v1.media_storage import MediaStorage
@@ -58,6 +64,7 @@ class ThumbnailResource(DirectServeJsonResource):
 
     async def _async_render_GET(self, request: SynapseRequest) -> None:
         set_cors_headers(request)
+        set_corp_headers(request)
         server_name, media_id, _ = parse_media_id(request)
         width = parse_integer(request, "width", required=True)
         height = parse_integer(request, "height", required=True)
@@ -304,6 +311,19 @@ class ThumbnailResource(DirectServeJsonResource):
             url_cache: True if this is from a URL cache.
             server_name: The server name, if this is a remote thumbnail.
         """
+        logger.debug(
+            "_select_and_respond_with_thumbnail: media_id=%s desired=%sx%s (%s) thumbnail_infos=%s",
+            media_id,
+            desired_width,
+            desired_height,
+            desired_method,
+            thumbnail_infos,
+        )
+
+        # If `dynamic_thumbnails` is enabled, we expect Synapse to go down a
+        # different code path to handle it.
+        assert not self.dynamic_thumbnails
+
         if thumbnail_infos:
             file_info = self._select_thumbnail(
                 desired_width,
@@ -379,8 +399,29 @@ class ThumbnailResource(DirectServeJsonResource):
                 file_info.thumbnail.length,
             )
         else:
+            # This might be because:
+            # 1. We can't create thumbnails for the given media (corrupted or
+            #    unsupported file type), or
+            # 2. The thumbnailing process never ran or errored out initially
+            #    when the media was first uploaded (these bugs should be
+            #    reported and fixed).
+            # Note that we don't attempt to generate a thumbnail now because
+            # `dynamic_thumbnails` is disabled.
             logger.info("Failed to find any generated thumbnails")
-            respond_404(request)
+
+            respond_with_json(
+                request,
+                400,
+                cs_error(
+                    "Cannot find any thumbnails for the requested media (%r). This might mean the media is not a supported_media_format=(%s) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)"
+                    % (
+                        request.postpath,
+                        ", ".join(THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP.keys()),
+                    ),
+                    code=Codes.UNKNOWN,
+                ),
+                send_cors=True,
+            )
 
     def _select_thumbnail(
         self,
diff --git a/synapse/rest/models.py b/synapse/rest/models.py
new file mode 100644
index 0000000000..ac39cda8e5
--- /dev/null
+++ b/synapse/rest/models.py
@@ -0,0 +1,23 @@
+from pydantic import BaseModel, Extra
+
+
+class RequestBodyModel(BaseModel):
+    """A custom version of Pydantic's BaseModel which
+
+     - ignores unknown fields and
+     - does not allow fields to be overwritten after construction,
+
+    but otherwise uses Pydantic's default behaviour.
+
+    Ignoring unknown fields is a useful default. It means that clients can provide
+    unstable field not known to the server without the request being refused outright.
+
+    Subclassing in this way is recommended by
+    https://pydantic-docs.helpmanual.io/usage/model_config/#change-behaviour-globally
+    """
+
+    class Config:
+        # By default, ignore fields that we don't recognise.
+        extra = Extra.ignore
+        # By default, don't allow fields to be reassigned after parsing.
+        allow_mutation = False
diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py
index 6ad558f5d1..e55924f597 100644
--- a/synapse/rest/synapse/client/__init__.py
+++ b/synapse/rest/synapse/client/__init__.py
@@ -20,6 +20,7 @@ from synapse.rest.synapse.client.new_user_consent import NewUserConsentResource
 from synapse.rest.synapse.client.pick_idp import PickIdpResource
 from synapse.rest.synapse.client.pick_username import pick_username_resource
 from synapse.rest.synapse.client.sso_register import SsoRegisterResource
+from synapse.rest.synapse.client.unsubscribe import UnsubscribeResource
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -41,6 +42,8 @@ def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resourc
         "/_synapse/client/pick_username": pick_username_resource(hs),
         "/_synapse/client/new_user_consent": NewUserConsentResource(hs),
         "/_synapse/client/sso_register": SsoRegisterResource(hs),
+        # Unsubscribe to notification emails link
+        "/_synapse/client/unsubscribe": UnsubscribeResource(hs),
     }
 
     # provider-specific SSO bits. Only load these if they are enabled, since they
diff --git a/synapse/rest/synapse/client/password_reset.py b/synapse/rest/synapse/client/password_reset.py
index 6ac9dbc7c9..b9402cfb75 100644
--- a/synapse/rest/synapse/client/password_reset.py
+++ b/synapse/rest/synapse/client/password_reset.py
@@ -17,7 +17,6 @@ from typing import TYPE_CHECKING, Tuple
 from twisted.web.server import Request
 
 from synapse.api.errors import ThreepidValidationError
-from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.http.server import DirectServeHtmlResource
 from synapse.http.servlet import parse_string
 from synapse.util.stringutils import assert_valid_client_secret
@@ -46,9 +45,6 @@ class PasswordResetSubmitTokenResource(DirectServeHtmlResource):
         self.clock = hs.get_clock()
         self.store = hs.get_datastores().main
 
-        self._local_threepid_handling_disabled_due_to_email_config = (
-            hs.config.email.local_threepid_handling_disabled_due_to_email_config
-        )
         self._confirmation_email_template = (
             hs.config.email.email_password_reset_template_confirmation_html
         )
@@ -59,8 +55,8 @@ class PasswordResetSubmitTokenResource(DirectServeHtmlResource):
             hs.config.email.email_password_reset_template_failure_html
         )
 
-        # This resource should not be mounted if threepid behaviour is not LOCAL
-        assert hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL
+        # This resource should only be mounted if email validation is enabled
+        assert hs.config.email.can_verify_email
 
     async def _async_render_GET(self, request: Request) -> Tuple[int, bytes]:
         sid = parse_string(request, "sid", required=True)
diff --git a/synapse/rest/synapse/client/unsubscribe.py b/synapse/rest/synapse/client/unsubscribe.py
new file mode 100644
index 0000000000..60321018f9
--- /dev/null
+++ b/synapse/rest/synapse/client/unsubscribe.py
@@ -0,0 +1,64 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from synapse.api.errors import StoreError
+from synapse.http.server import DirectServeHtmlResource, respond_with_html_bytes
+from synapse.http.servlet import parse_string
+from synapse.http.site import SynapseRequest
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
+class UnsubscribeResource(DirectServeHtmlResource):
+    """
+    To allow pusher to be delete by clicking a link (ie. GET request)
+    """
+
+    SUCCESS_HTML = b"<html><body>You have been unsubscribed</body><html>"
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self.notifier = hs.get_notifier()
+        self.auth = hs.get_auth()
+        self.pusher_pool = hs.get_pusherpool()
+        self.macaroon_generator = hs.get_macaroon_generator()
+
+    async def _async_render_GET(self, request: SynapseRequest) -> None:
+        token = parse_string(request, "access_token", required=True)
+        app_id = parse_string(request, "app_id", required=True)
+        pushkey = parse_string(request, "pushkey", required=True)
+
+        user_id = self.macaroon_generator.verify_delete_pusher_token(
+            token, app_id, pushkey
+        )
+
+        try:
+            await self.pusher_pool.remove_pusher(
+                app_id=app_id, pushkey=pushkey, user_id=user_id
+            )
+        except StoreError as se:
+            if se.code != 404:
+                # This is fine: they're already unsubscribed
+                raise
+
+        self.notifier.on_new_replication_data()
+
+        respond_with_html_bytes(
+            request,
+            200,
+            UnsubscribeResource.SUCCESS_HTML,
+        )
diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py
index 04b035a1b1..6f7ac54c65 100644
--- a/synapse/rest/well_known.py
+++ b/synapse/rest/well_known.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 from typing import TYPE_CHECKING, Optional
 
@@ -44,6 +43,14 @@ class WellKnownBuilder:
                 "base_url": self._config.registration.default_identity_server
             }
 
+        if self._config.server.extra_well_known_client_content:
+            for (
+                key,
+                value,
+            ) in self._config.server.extra_well_known_client_content.items():
+                if key not in result:
+                    result[key] = value
+
         return result
 
 
diff --git a/synapse/server.py b/synapse/server.py
index a66ec228db..c2e55bf0b1 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -29,6 +29,7 @@ from twisted.web.iweb import IPolicyForHTTPS
 from twisted.web.resource import Resource
 
 from synapse.api.auth import Auth
+from synapse.api.auth_blocking import AuthBlocking
 from synapse.api.filtering import Filtering
 from synapse.api.ratelimiting import Ratelimiter, RequestRatelimiter
 from synapse.appservice.api import ApplicationServiceApi
@@ -55,7 +56,7 @@ from synapse.handlers.account_data import AccountDataHandler
 from synapse.handlers.account_validity import AccountValidityHandler
 from synapse.handlers.admin import AdminHandler
 from synapse.handlers.appservice import ApplicationServicesHandler
-from synapse.handlers.auth import AuthHandler, MacaroonGenerator, PasswordAuthProvider
+from synapse.handlers.auth import AuthHandler, PasswordAuthProvider
 from synapse.handlers.cas import CasHandler
 from synapse.handlers.deactivate_account import DeactivateAccountHandler
 from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler
@@ -129,6 +130,7 @@ from synapse.streams.events import EventSources
 from synapse.types import DomainSpecificString, ISynapseReactor
 from synapse.util import Clock
 from synapse.util.distributor import Distributor
+from synapse.util.macaroons import MacaroonGenerator
 from synapse.util.ratelimitutils import FederationRateLimiter
 from synapse.util.stringutils import random_string
 
@@ -380,6 +382,10 @@ class HomeServer(metaclass=abc.ABCMeta):
         return Auth(self)
 
     @cache_in_self
+    def get_auth_blocking(self) -> AuthBlocking:
+        return AuthBlocking(self)
+
+    @cache_in_self
     def get_http_client_context_factory(self) -> IPolicyForHTTPS:
         if self.config.tls.use_insecure_ssl_client_just_for_testing_do_not_use:
             return InsecureInterceptableContextFactory()
@@ -487,7 +493,9 @@ class HomeServer(metaclass=abc.ABCMeta):
 
     @cache_in_self
     def get_macaroon_generator(self) -> MacaroonGenerator:
-        return MacaroonGenerator(self)
+        return MacaroonGenerator(
+            self.get_clock(), self.hostname, self.config.key.macaroon_secret_key
+        )
 
     @cache_in_self
     def get_device_handler(self):
@@ -748,7 +756,9 @@ class HomeServer(metaclass=abc.ABCMeta):
     @cache_in_self
     def get_federation_ratelimiter(self) -> FederationRateLimiter:
         return FederationRateLimiter(
-            self.get_clock(), config=self.config.ratelimiting.rc_federation
+            self.get_clock(),
+            config=self.config.ratelimiting.rc_federation,
+            metrics_name="federation_servlets",
         )
 
     @cache_in_self
diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py
index 6863020778..3134cd2d3d 100644
--- a/synapse/server_notices/resource_limits_server_notices.py
+++ b/synapse/server_notices/resource_limits_server_notices.py
@@ -37,7 +37,7 @@ class ResourceLimitsServerNotices:
         self._server_notices_manager = hs.get_server_notices_manager()
         self._store = hs.get_datastores().main
         self._storage_controllers = hs.get_storage_controllers()
-        self._auth = hs.get_auth()
+        self._auth_blocking = hs.get_auth_blocking()
         self._config = hs.config
         self._resouce_limited = False
         self._account_data_handler = hs.get_account_data_handler()
@@ -91,7 +91,7 @@ class ResourceLimitsServerNotices:
             # Normally should always pass in user_id to check_auth_blocking
             # if you have it, but in this case are checking what would happen
             # to other users if they were to arrive.
-            await self._auth.check_auth_blocking()
+            await self._auth_blocking.check_auth_blocking()
         except ResourceLimitError as e:
             limit_msg = e.msg
             limit_type = e.limit_type
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index 8ecab86ec7..564e3705c2 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -102,6 +102,10 @@ class ServerNoticesManager:
         Returns:
             The room's ID, or None if no room could be found.
         """
+        # If there is no server notices MXID, then there is no server notices room
+        if self.server_notices_mxid is None:
+            return None
+
         rooms = await self._store.get_rooms_for_local_user_where_membership_is(
             user_id, [Membership.INVITE, Membership.JOIN]
         )
@@ -111,8 +115,10 @@ class ServerNoticesManager:
             # be joined. This is kinda deliberate, in that if somebody somehow
             # manages to invite the system user to a room, that doesn't make it
             # the server notices room.
-            user_ids = await self._store.get_users_in_room(room.room_id)
-            if len(user_ids) <= 2 and self.server_notices_mxid in user_ids:
+            is_server_notices_room = await self._store.check_local_user_in_room(
+                user_id=self.server_notices_mxid, room_id=room.room_id
+            )
+            if is_server_notices_room:
                 # we found a room which our user shares with the system notice
                 # user
                 return room.room_id
@@ -244,7 +250,7 @@ class ServerNoticesManager:
         assert self.server_notices_mxid is not None
 
         notice_user_data_in_room = await self._message_handler.get_room_data(
-            self.server_notices_mxid,
+            create_requester(self.server_notices_mxid),
             room_id,
             EventTypes.Member,
             self.server_notices_mxid,
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index da25f20ae5..3787d35b24 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import heapq
 import logging
-from collections import defaultdict
+from collections import ChainMap, defaultdict
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -24,14 +24,12 @@ from typing import (
     DefaultDict,
     Dict,
     FrozenSet,
-    Iterable,
     List,
     Mapping,
     Optional,
     Sequence,
     Set,
     Tuple,
-    Union,
 )
 
 import attr
@@ -43,9 +41,10 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersio
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.logging.context import ContextResourceUsage
+from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
 from synapse.state import v1, v2
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.roommember import ProfileInfo
+from synapse.storage.state import StateFilter
 from synapse.types import StateMap
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -53,6 +52,7 @@ from synapse.util.metrics import Measure, measure_func
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
+    from synapse.storage.controllers import StateStorageController
     from synapse.storage.databases.main import DataStore
 
 logger = logging.getLogger(__name__)
@@ -82,17 +82,26 @@ def _gen_state_id() -> str:
 
 
 class _StateCacheEntry:
-    __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
+    __slots__ = ["_state", "state_group", "prev_group", "delta_ids"]
 
     def __init__(
         self,
-        state: StateMap[str],
+        state: Optional[StateMap[str]],
         state_group: Optional[int],
         prev_group: Optional[int] = None,
         delta_ids: Optional[StateMap[str]] = None,
     ):
+        if state is None and state_group is None and prev_group is None:
+            raise Exception("One of state, state_group or prev_group must be not None")
+
+        if prev_group is not None and delta_ids is None:
+            raise Exception("If prev_group is set so must delta_ids")
+
         # A map from (type, state_key) to event_id.
-        self.state = frozendict(state)
+        #
+        # This can be None if we have a `state_group` (as then we can fetch the
+        # state from the DB.)
+        self._state = frozendict(state) if state is not None else None
 
         # the ID of a state group if one and only one is involved.
         # otherwise, None otherwise?
@@ -101,20 +110,60 @@ class _StateCacheEntry:
         self.prev_group = prev_group
         self.delta_ids = frozendict(delta_ids) if delta_ids is not None else None
 
-        # The `state_id` is a unique ID we generate that can be used as ID for
-        # this collection of state. Usually this would be the same as the
-        # state group, but on worker instances we can't generate a new state
-        # group each time we resolve state, so we generate a separate one that
-        # isn't persisted and is used solely for caches.
-        # `state_id` is either a state_group (and so an int) or a string. This
-        # ensures we don't accidentally persist a state_id as a stateg_group
-        if state_group:
-            self.state_id: Union[str, int] = state_group
-        else:
-            self.state_id = _gen_state_id()
+    async def get_state(
+        self,
+        state_storage: "StateStorageController",
+        state_filter: Optional["StateFilter"] = None,
+    ) -> StateMap[str]:
+        """Get the state map for this entry, either from the in-memory state or
+        looking up the state group in the DB.
+        """
+
+        if self._state is not None:
+            return self._state
+
+        if self.state_group is not None:
+            return await state_storage.get_state_ids_for_group(
+                self.state_group, state_filter
+            )
+
+        assert self.prev_group is not None and self.delta_ids is not None
+
+        prev_state = await state_storage.get_state_ids_for_group(
+            self.prev_group, state_filter
+        )
+
+        # ChainMap expects MutableMapping, but since we're using it immutably
+        # its safe to give it immutable maps.
+        return ChainMap(self.delta_ids, prev_state)  # type: ignore[arg-type]
+
+    def set_state_group(self, state_group: int) -> None:
+        """Update the state group assigned to this state (e.g. after we've
+        persisted it).
+
+        Note: this will cause the cache entry to drop any stored state.
+        """
+
+        self.state_group = state_group
+
+        # We clear out the state as we know longer need to explicitly keep it in
+        # the `state_cache` (as the store state group cache will do that).
+        self._state = None
 
     def __len__(self) -> int:
-        return len(self.state)
+        # The len should be used to estimate how large this cache entry is, for
+        # cache eviction purposes. This is why it's fine to return 1 if we're
+        # not storing any state.
+
+        length = 0
+
+        if self._state:
+            length += len(self._state)
+
+        if self.delta_ids:
+            length += len(self.delta_ids)
+
+        return length or 1  # Make sure its not 0.
 
 
 class StateHandler:
@@ -129,30 +178,42 @@ class StateHandler:
         self.hs = hs
         self._state_resolution_handler = hs.get_state_resolution_handler()
         self._storage_controllers = hs.get_storage_controllers()
+        self._events_shard_config = hs.config.worker.events_shard_config
+        self._instance_name = hs.get_instance_name()
 
-    async def get_current_state_ids(
+        self._update_current_state_client = (
+            ReplicationUpdateCurrentStateRestServlet.make_client(hs)
+        )
+
+    async def compute_state_after_events(
         self,
         room_id: str,
-        latest_event_ids: Collection[str],
+        event_ids: Collection[str],
+        state_filter: Optional[StateFilter] = None,
     ) -> StateMap[str]:
-        """Get the current state, or the state at a set of events, for a room
+        """Fetch the state after each of the given event IDs. Resolve them and return.
+
+        This is typically used where `event_ids` is a collection of forward extremities
+        in a room, intended to become the `prev_events` of a new event E. If so, the
+        return value of this function represents the state before E.
 
         Args:
-            room_id:
-            latest_event_ids: The forward extremities to resolve.
+            room_id: the room_id containing the given events.
+            event_ids: the events whose state should be fetched and resolved.
 
         Returns:
-            the state dict, mapping from (event_type, state_key) -> event_id
+            the state dict (a mapping from (event_type, state_key) -> event_id) which
+            holds the resolution of the states after the given event IDs.
         """
-        logger.debug("calling resolve_state_groups from get_current_state_ids")
-        ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
-        return ret.state
+        logger.debug("calling resolve_state_groups from compute_state_after_events")
+        ret = await self.resolve_state_groups_for_events(room_id, event_ids)
+        return await ret.get_state(self._state_storage_controller, state_filter)
 
-    async def get_current_users_in_room(
+    async def get_current_user_ids_in_room(
         self, room_id: str, latest_event_ids: List[str]
-    ) -> Dict[str, ProfileInfo]:
+    ) -> Set[str]:
         """
-        Get the users who are currently in a room.
+        Get the users IDs who are currently in a room.
 
         Note: This is much slower than using the equivalent method
         `DataStore.get_users_in_room` or `DataStore.get_users_in_room_with_profiles`,
@@ -163,14 +224,15 @@ class StateHandler:
             room_id: The ID of the room.
             latest_event_ids: Precomputed list of latest event IDs. Will be computed if None.
         Returns:
-            Dictionary of user IDs to their profileinfo.
+            Set of user IDs in the room.
         """
 
         assert latest_event_ids is not None
 
-        logger.debug("calling resolve_state_groups from get_current_users_in_room")
+        logger.debug("calling resolve_state_groups from get_current_user_ids_in_room")
         entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
-        return await self.store.get_joined_users_from_state(room_id, entry)
+        state = await entry.get_state(self._state_storage_controller, StateFilter.all())
+        return await self.store.get_joined_user_ids_from_state(room_id, state)
 
     async def get_hosts_in_room_at_events(
         self, room_id: str, event_ids: Collection[str]
@@ -185,13 +247,14 @@ class StateHandler:
             The hosts in the room at the given events
         """
         entry = await self.resolve_state_groups_for_events(room_id, event_ids)
-        return await self.store.get_joined_hosts(room_id, entry)
+        state = await entry.get_state(self._state_storage_controller, StateFilter.all())
+        return await self.store.get_joined_hosts(room_id, state, entry)
 
     async def compute_event_context(
         self,
         event: EventBase,
         state_ids_before_event: Optional[StateMap[str]] = None,
-        partial_state: bool = False,
+        partial_state: Optional[bool] = None,
     ) -> EventContext:
         """Build an EventContext structure for a non-outlier event.
 
@@ -206,10 +269,18 @@ class StateHandler:
                 it can't be calculated from existing events. This is normally
                 only specified when receiving an event from federation where we
                 don't have the prev events, e.g. when backfilling.
-            partial_state: True if `state_ids_before_event` is partial and omits
-                non-critical membership events
+            partial_state:
+                `True` if `state_ids_before_event` is partial and omits non-critical
+                membership events.
+                `False` if `state_ids_before_event` is the full state.
+                `None` when `state_ids_before_event` is not provided. In this case, the
+                flag will be calculated based on `event`'s prev events.
         Returns:
             The event context.
+
+        Raises:
+            RuntimeError if `state_ids_before_event` is not provided and one or more
+                prev events are missing or outliers.
         """
 
         assert not event.internal_metadata.is_outlier()
@@ -220,17 +291,28 @@ class StateHandler:
         #
         if state_ids_before_event:
             # if we're given the state before the event, then we use that
-            state_group_before_event = None
             state_group_before_event_prev_group = None
             deltas_to_state_group_before_event = None
-            entry = None
 
+            # .. though we need to get a state group for it.
+            state_group_before_event = (
+                await self._state_storage_controller.store_state_group(
+                    event.event_id,
+                    event.room_id,
+                    prev_group=None,
+                    delta_ids=None,
+                    current_state_ids=state_ids_before_event,
+                )
+            )
+
+            # the partial_state flag must be provided
+            assert partial_state is not None
         else:
             # otherwise, we'll need to resolve the state across the prev_events.
 
             # partial_state should not be set explicitly in this case:
             # we work it out dynamically
-            assert not partial_state
+            assert partial_state is None
 
             # if any of the prev-events have partial state, so do we.
             # (This is slightly racy - the prev-events might get fixed up before we use
@@ -240,49 +322,49 @@ class StateHandler:
             incomplete_prev_events = await self.store.get_partial_state_events(
                 prev_event_ids
             )
-            if any(incomplete_prev_events.values()):
+            partial_state = any(incomplete_prev_events.values())
+            if partial_state:
                 logger.debug(
                     "New/incoming event %s refers to prev_events %s with partial state",
                     event.event_id,
                     [k for (k, v) in incomplete_prev_events.items() if v],
                 )
-                partial_state = True
 
             logger.debug("calling resolve_state_groups from compute_event_context")
+            # we've already taken into account partial state, so no need to wait for
+            # complete state here.
             entry = await self.resolve_state_groups_for_events(
-                event.room_id, event.prev_event_ids()
+                event.room_id,
+                event.prev_event_ids(),
+                await_full_state=False,
             )
 
-            state_ids_before_event = entry.state
-            state_group_before_event = entry.state_group
             state_group_before_event_prev_group = entry.prev_group
             deltas_to_state_group_before_event = entry.delta_ids
+            state_ids_before_event = None
+
+            # We make sure that we have a state group assigned to the state.
+            if entry.state_group is None:
+                # store_state_group requires us to have either a previous state group
+                # (with deltas) or the complete state map. So, if we don't have a
+                # previous state group, load the complete state map now.
+                if state_group_before_event_prev_group is None:
+                    state_ids_before_event = await entry.get_state(
+                        self._state_storage_controller, StateFilter.all()
+                    )
 
-        #
-        # make sure that we have a state group at that point. If it's not a state event,
-        # that will be the state group for the new event. If it *is* a state event,
-        # it might get rejected (in which case we'll need to persist it with the
-        # previous state group)
-        #
-
-        if not state_group_before_event:
-            state_group_before_event = (
-                await self._state_storage_controller.store_state_group(
-                    event.event_id,
-                    event.room_id,
-                    prev_group=state_group_before_event_prev_group,
-                    delta_ids=deltas_to_state_group_before_event,
-                    current_state_ids=state_ids_before_event,
+                state_group_before_event = (
+                    await self._state_storage_controller.store_state_group(
+                        event.event_id,
+                        event.room_id,
+                        prev_group=state_group_before_event_prev_group,
+                        delta_ids=deltas_to_state_group_before_event,
+                        current_state_ids=state_ids_before_event,
+                    )
                 )
-            )
-
-            # Assign the new state group to the cached state entry.
-            #
-            # Note that this can race in that we could generate multiple state
-            # groups for the same state entry, but that is just inefficient
-            # rather than dangerous.
-            if entry and entry.state_group is None:
-                entry.state_group = state_group_before_event
+                entry.set_state_group(state_group_before_event)
+            else:
+                state_group_before_event = entry.state_group
 
         #
         # now if it's not a state event, we're done
@@ -304,13 +386,18 @@ class StateHandler:
         #
 
         key = (event.type, event.state_key)
-        if key in state_ids_before_event:
-            replaces = state_ids_before_event[key]
-            if replaces != event.event_id:
-                event.unsigned["replaces_state"] = replaces
 
-        state_ids_after_event = dict(state_ids_before_event)
-        state_ids_after_event[key] = event.event_id
+        if state_ids_before_event is not None:
+            replaces = state_ids_before_event.get(key)
+        else:
+            replaces_state_map = await entry.get_state(
+                self._state_storage_controller, StateFilter.from_types([key])
+            )
+            replaces = replaces_state_map.get(key)
+
+        if replaces and replaces != event.event_id:
+            event.unsigned["replaces_state"] = replaces
+
         delta_ids = {key: event.event_id}
 
         state_group_after_event = (
@@ -319,7 +406,7 @@ class StateHandler:
                 event.room_id,
                 prev_group=state_group_before_event,
                 delta_ids=delta_ids,
-                current_state_ids=state_ids_after_event,
+                current_state_ids=None,
             )
         )
 
@@ -335,7 +422,7 @@ class StateHandler:
 
     @measure_func()
     async def resolve_state_groups_for_events(
-        self, room_id: str, event_ids: Collection[str]
+        self, room_id: str, event_ids: Collection[str], await_full_state: bool = True
     ) -> _StateCacheEntry:
         """Given a list of event_ids this method fetches the state at each
         event, resolves conflicts between them and returns them.
@@ -343,14 +430,20 @@ class StateHandler:
         Args:
             room_id
             event_ids
+            await_full_state: if true, will block if we do not yet have complete
+               state at these events.
 
         Returns:
             The resolved state
+
+        Raises:
+            RuntimeError if we don't have a state group for one or more of the events
+               (ie. they are outliers or unknown)
         """
         logger.debug("resolve_state_groups event_ids %s", event_ids)
 
         state_groups = await self._state_storage_controller.get_state_group_for_events(
-            event_ids
+            event_ids, await_full_state=await_full_state
         )
 
         state_group_ids = state_groups.values()
@@ -359,9 +452,6 @@ class StateHandler:
         state_group_ids_set = set(state_group_ids)
         if len(state_group_ids_set) == 1:
             (state_group_id,) = state_group_ids_set
-            state = await self._state_storage_controller.get_state_for_groups(
-                state_group_ids_set
-            )
             (
                 prev_group,
                 delta_ids,
@@ -369,7 +459,7 @@ class StateHandler:
                 state_group_id
             )
             return _StateCacheEntry(
-                state=state[state_group_id],
+                state=None,
                 state_group=state_group_id,
                 prev_group=prev_group,
                 delta_ids=delta_ids,
@@ -392,30 +482,23 @@ class StateHandler:
         )
         return result
 
-    async def resolve_events(
-        self,
-        room_version: str,
-        state_sets: Collection[Iterable[EventBase]],
-        event: EventBase,
-    ) -> StateMap[EventBase]:
-        logger.info(
-            "Resolving state for %s with %d groups", event.room_id, len(state_sets)
-        )
-        state_set_ids = [
-            {(ev.type, ev.state_key): ev.event_id for ev in st} for st in state_sets
-        ]
-
-        state_map = {ev.event_id: ev for st in state_sets for ev in st}
+    async def update_current_state(self, room_id: str) -> None:
+        """Recalculates the current state for a room, and persists it.
 
-        new_state = await self._state_resolution_handler.resolve_events_with_store(
-            event.room_id,
-            room_version,
-            state_set_ids,
-            event_map=state_map,
-            state_res_store=StateResolutionStore(self.store),
-        )
+        Raises:
+            SynapseError(502): if all attempts to connect to the event persister worker
+                fail
+        """
+        writer_instance = self._events_shard_config.get_instance(room_id)
+        if writer_instance != self._instance_name:
+            await self._update_current_state_client(
+                instance_name=writer_instance,
+                room_id=room_id,
+            )
+            return
 
-        return {key: state_map[ev_id] for key, ev_id in new_state.items()}
+        assert self._storage_controllers.persistence is not None
+        await self._storage_controllers.persistence.update_current_state(room_id)
 
 
 @attr.s(slots=True, auto_attribs=True)
@@ -444,6 +527,15 @@ _biggest_room_by_db_counter = Counter(
     "expensive room for state resolution",
 )
 
+_cpu_times = Histogram(
+    "synapse_state_res_cpu_for_all_rooms_seconds",
+    "CPU time (utime+stime) spent computing a single state resolution",
+)
+_db_times = Histogram(
+    "synapse_state_res_db_for_all_rooms_seconds",
+    "Database time spent computing a single state resolution",
+)
+
 
 class StateResolutionHandler:
     """Responsible for doing state conflict resolution.
@@ -609,6 +701,9 @@ class StateResolutionHandler:
         room_metrics.db_time += rusage.db_txn_duration_sec
         room_metrics.db_events += rusage.evt_db_fetch_count
 
+        _cpu_times.observe(rusage.ru_utime + rusage.ru_stime)
+        _db_times.observe(rusage.db_txn_duration_sec)
+
     def _report_metrics(self) -> None:
         if not self._state_res_metrics:
             # no state res has happened since the last iteration: don't bother logging.
@@ -698,7 +793,7 @@ def _make_state_cache_entry(
         old_state_event_ids = set(state.values())
         if new_state_event_ids == old_state_event_ids:
             # got an exact match.
-            return _StateCacheEntry(state=new_state, state_group=sg)
+            return _StateCacheEntry(state=None, state_group=sg)
 
     # TODO: We want to create a state group for this set of events, to
     # increase cache hits, but we need to make sure that it doesn't
@@ -709,14 +804,25 @@ def _make_state_cache_entry(
     delta_ids: Optional[StateMap[str]] = None
 
     for old_group, old_state in state_groups_ids.items():
+        if old_state.keys() - new_state.keys():
+            # Currently we don't support deltas that remove keys from the state
+            # map, so we have to ignore this group as a candidate to base the
+            # new group on.
+            continue
+
         n_delta_ids = {k: v for k, v in new_state.items() if old_state.get(k) != v}
         if not delta_ids or len(n_delta_ids) < len(delta_ids):
             prev_group = old_group
             delta_ids = n_delta_ids
 
-    return _StateCacheEntry(
-        state=new_state, state_group=None, prev_group=prev_group, delta_ids=delta_ids
-    )
+    if prev_group is not None:
+        # If we have a prev group and deltas then we can drop the new state from
+        # the cache (to reduce memory usage).
+        return _StateCacheEntry(
+            state=None, state_group=None, prev_group=prev_group, delta_ids=delta_ids
+        )
+    else:
+        return _StateCacheEntry(state=new_state, state_group=None)
 
 
 @attr.s(slots=True, auto_attribs=True)
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 499a328201..500e384695 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -30,7 +30,7 @@ from typing import (
 from synapse import event_auth
 from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
-from synapse.api.room_versions import RoomVersion, RoomVersions
+from synapse.api.room_versions import RoomVersion
 from synapse.events import EventBase
 from synapse.types import MutableStateMap, StateMap
 
@@ -330,8 +330,7 @@ def _resolve_auth_events(
         auth_events[(prev_event.type, prev_event.state_key)] = prev_event
         try:
             # The signatures have already been checked at this point
-            event_auth.check_auth_rules_for_event(
-                RoomVersions.V1,
+            event_auth.check_state_dependent_auth_rules(
                 event,
                 auth_events.values(),
             )
@@ -348,8 +347,7 @@ def _resolve_normal_events(
     for event in _ordered_events(events):
         try:
             # The signatures have already been checked at this point
-            event_auth.check_auth_rules_for_event(
-                RoomVersions.V1,
+            event_auth.check_state_dependent_auth_rules(
                 event,
                 auth_events.values(),
             )
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index c618df2fde..af03851c71 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -17,12 +17,14 @@ import itertools
 import logging
 from typing import (
     Any,
+    Awaitable,
     Callable,
     Collection,
     Dict,
     Generator,
     Iterable,
     List,
+    Mapping,
     Optional,
     Sequence,
     Set,
@@ -30,33 +32,58 @@ from typing import (
     overload,
 )
 
-from typing_extensions import Literal
+from typing_extensions import Literal, Protocol
 
-import synapse.state
 from synapse import event_auth
 from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.api.room_versions import RoomVersion
 from synapse.events import EventBase
 from synapse.types import MutableStateMap, StateMap
-from synapse.util import Clock
 
 logger = logging.getLogger(__name__)
 
 
+class Clock(Protocol):
+    # This is usually synapse.util.Clock, but it's replaced with a FakeClock in tests.
+    # We only ever sleep(0) though, so that other async functions can make forward
+    # progress without waiting for stateres to complete.
+    def sleep(self, duration_ms: float) -> Awaitable[None]:
+        ...
+
+
+class StateResolutionStore(Protocol):
+    # This is usually synapse.state.StateResolutionStore, but it's replaced with a
+    # TestStateResolutionStore in tests.
+    def get_events(
+        self, event_ids: Collection[str], allow_rejected: bool = False
+    ) -> Awaitable[Dict[str, EventBase]]:
+        ...
+
+    def get_auth_chain_difference(
+        self, room_id: str, state_sets: List[Set[str]]
+    ) -> Awaitable[Set[str]]:
+        ...
+
+
 # We want to await to the reactor occasionally during state res when dealing
 # with large data sets, so that we don't exhaust the reactor. This is done by
 # awaiting to reactor during loops every N iterations.
 _AWAIT_AFTER_ITERATIONS = 100
 
 
+__all__ = [
+    "resolve_events_with_store",
+]
+
+
 async def resolve_events_with_store(
     clock: Clock,
     room_id: str,
     room_version: RoomVersion,
     state_sets: Sequence[StateMap[str]],
     event_map: Optional[Dict[str, EventBase]],
-    state_res_store: "synapse.state.StateResolutionStore",
+    state_res_store: StateResolutionStore,
 ) -> StateMap[str]:
     """Resolves the state using the v2 state resolution algorithm
 
@@ -194,7 +221,7 @@ async def _get_power_level_for_sender(
     room_id: str,
     event_id: str,
     event_map: Dict[str, EventBase],
-    state_res_store: "synapse.state.StateResolutionStore",
+    state_res_store: StateResolutionStore,
 ) -> int:
     """Return the power level of the sender of the given event according to
     their auth events.
@@ -243,41 +270,42 @@ async def _get_power_level_for_sender(
 
 async def _get_auth_chain_difference(
     room_id: str,
-    state_sets: Sequence[StateMap[str]],
-    event_map: Dict[str, EventBase],
-    state_res_store: "synapse.state.StateResolutionStore",
+    state_sets: Sequence[Mapping[Any, str]],
+    unpersisted_events: Dict[str, EventBase],
+    state_res_store: StateResolutionStore,
 ) -> Set[str]:
     """Compare the auth chains of each state set and return the set of events
-    that only appear in some but not all of the auth chains.
+    that only appear in some, but not all of the auth chains.
 
     Args:
-        state_sets
-        event_map
-        state_res_store
+        state_sets: The input state sets we are trying to resolve across.
+        unpersisted_events: A map from event ID to EventBase containing all unpersisted
+            events involved in this resolution.
+        state_res_store:
 
     Returns:
-        Set of event IDs
+        The auth difference of the given state sets, as a set of event IDs.
     """
 
     # The `StateResolutionStore.get_auth_chain_difference` function assumes that
     # all events passed to it (and their auth chains) have been persisted
-    # previously. This is not the case for any events in the `event_map`, and so
-    # we need to manually handle those events.
+    # previously. We need to manually handle any other events that are yet to be
+    # persisted.
     #
-    # We do this by:
-    #   1. calculating the auth chain difference for the state sets based on the
-    #      events in `event_map` alone
-    #   2. replacing any events in the state_sets that are also in `event_map`
-    #      with their auth events (recursively), and then calling
-    #      `store.get_auth_chain_difference` as normal
-    #   3. adding the results of 1 and 2 together.
-
-    # Map from event ID in `event_map` to their auth event IDs, and their auth
-    # event IDs if they appear in the `event_map`. This is the intersection of
-    # the event's auth chain with the events in the `event_map` *plus* their
+    # We do this in three steps:
+    #   1. Compute the set of unpersisted events belonging to the auth difference.
+    #   2. Replacing any unpersisted events in the state_sets with their auth events,
+    #      recursively, until the state_sets contain only persisted events.
+    #      Then we call `store.get_auth_chain_difference` as normal, which computes
+    #      the set of persisted events belonging to the auth difference.
+    #   3. Adding the results of 1 and 2 together.
+
+    # Map from event ID in `unpersisted_events` to their auth event IDs, and their auth
+    # event IDs if they appear in the `unpersisted_events`. This is the intersection of
+    # the event's auth chain with the events in `unpersisted_events` *plus* their
     # auth event IDs.
     events_to_auth_chain: Dict[str, Set[str]] = {}
-    for event in event_map.values():
+    for event in unpersisted_events.values():
         chain = {event.event_id}
         events_to_auth_chain[event.event_id] = chain
 
@@ -285,16 +313,16 @@ async def _get_auth_chain_difference(
         while to_search:
             for auth_id in to_search.pop().auth_event_ids():
                 chain.add(auth_id)
-                auth_event = event_map.get(auth_id)
+                auth_event = unpersisted_events.get(auth_id)
                 if auth_event:
                     to_search.append(auth_event)
 
-    # We now a) calculate the auth chain difference for the unpersisted events
-    # and b) work out the state sets to pass to the store.
+    # We now 1) calculate the auth chain difference for the unpersisted events
+    # and 2) work out the state sets to pass to the store.
     #
-    # Note: If the `event_map` is empty (which is the common case), we can do a
+    # Note: If there are no `unpersisted_events` (which is the common case), we can do a
     # much simpler calculation.
-    if event_map:
+    if unpersisted_events:
         # The list of state sets to pass to the store, where each state set is a set
         # of the event ids making up the state. This is similar to `state_sets`,
         # except that (a) we only have event ids, not the complete
@@ -317,14 +345,18 @@ async def _get_auth_chain_difference(
             for event_id in state_set.values():
                 event_chain = events_to_auth_chain.get(event_id)
                 if event_chain is not None:
-                    # We have an event in `event_map`. We add all the auth
-                    # events that it references (that aren't also in `event_map`).
-                    set_ids.update(e for e in event_chain if e not in event_map)
+                    # We have an unpersisted event. We add all the auth
+                    # events that it references which are also unpersisted.
+                    set_ids.update(
+                        e for e in event_chain if e not in unpersisted_events
+                    )
 
                     # We also add the full chain of unpersisted event IDs
                     # referenced by this state set, so that we can work out the
                     # auth chain difference of the unpersisted events.
-                    unpersisted_ids.update(e for e in event_chain if e in event_map)
+                    unpersisted_ids.update(
+                        e for e in event_chain if e in unpersisted_events
+                    )
                 else:
                     set_ids.add(event_id)
 
@@ -334,15 +366,15 @@ async def _get_auth_chain_difference(
         union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
         intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
 
-        difference_from_event_map: Collection[str] = union - intersection
+        auth_difference_unpersisted_part: Collection[str] = union - intersection
     else:
-        difference_from_event_map = ()
+        auth_difference_unpersisted_part = ()
         state_sets_ids = [set(state_set.values()) for state_set in state_sets]
 
     difference = await state_res_store.get_auth_chain_difference(
         room_id, state_sets_ids
     )
-    difference.update(difference_from_event_map)
+    difference.update(auth_difference_unpersisted_part)
 
     return difference
 
@@ -406,8 +438,8 @@ async def _add_event_and_auth_chain_to_graph(
     room_id: str,
     event_id: str,
     event_map: Dict[str, EventBase],
-    state_res_store: "synapse.state.StateResolutionStore",
-    auth_diff: Set[str],
+    state_res_store: StateResolutionStore,
+    full_conflicted_set: Set[str],
 ) -> None:
     """Helper function for _reverse_topological_power_sort that add the event
     and its auth chain (that is in the auth diff) to the graph
@@ -418,7 +450,7 @@ async def _add_event_and_auth_chain_to_graph(
         event_id: Event to add to the graph
         event_map
         state_res_store
-        auth_diff: Set of event IDs that are in the auth difference.
+        full_conflicted_set: Set of event IDs that are in the full conflicted set.
     """
 
     state = [event_id]
@@ -428,7 +460,7 @@ async def _add_event_and_auth_chain_to_graph(
 
         event = await _get_event(room_id, eid, event_map, state_res_store)
         for aid in event.auth_event_ids():
-            if aid in auth_diff:
+            if aid in full_conflicted_set:
                 if aid not in graph:
                     state.append(aid)
 
@@ -440,8 +472,8 @@ async def _reverse_topological_power_sort(
     room_id: str,
     event_ids: Iterable[str],
     event_map: Dict[str, EventBase],
-    state_res_store: "synapse.state.StateResolutionStore",
-    auth_diff: Set[str],
+    state_res_store: StateResolutionStore,
+    full_conflicted_set: Set[str],
 ) -> List[str]:
     """Returns a list of the event_ids sorted by reverse topological ordering,
     and then by power level and origin_server_ts
@@ -452,7 +484,7 @@ async def _reverse_topological_power_sort(
         event_ids: The events to sort
         event_map
         state_res_store
-        auth_diff: Set of event IDs that are in the auth difference.
+        full_conflicted_set: Set of event IDs that are in the full conflicted set.
 
     Returns:
         The sorted list
@@ -461,7 +493,7 @@ async def _reverse_topological_power_sort(
     graph: Dict[str, Set[str]] = {}
     for idx, event_id in enumerate(event_ids, start=1):
         await _add_event_and_auth_chain_to_graph(
-            graph, room_id, event_id, event_map, state_res_store, auth_diff
+            graph, room_id, event_id, event_map, state_res_store, full_conflicted_set
         )
 
         # We await occasionally when we're working with large data sets to
@@ -501,7 +533,7 @@ async def _iterative_auth_checks(
     event_ids: List[str],
     base_state: StateMap[str],
     event_map: Dict[str, EventBase],
-    state_res_store: "synapse.state.StateResolutionStore",
+    state_res_store: StateResolutionStore,
 ) -> MutableStateMap[str]:
     """Sequentially apply auth checks to each event in given list, updating the
     state as it goes along.
@@ -546,8 +578,7 @@ async def _iterative_auth_checks(
                     auth_events[key] = event_map[ev_id]
 
         try:
-            event_auth.check_auth_rules_for_event(
-                room_version,
+            event_auth.check_state_dependent_auth_rules(
                 event,
                 auth_events.values(),
             )
@@ -570,7 +601,7 @@ async def _mainline_sort(
     event_ids: List[str],
     resolved_power_event_id: Optional[str],
     event_map: Dict[str, EventBase],
-    state_res_store: "synapse.state.StateResolutionStore",
+    state_res_store: StateResolutionStore,
 ) -> List[str]:
     """Returns a sorted list of event_ids sorted by mainline ordering based on
     the given event resolved_power_event_id
@@ -639,7 +670,7 @@ async def _get_mainline_depth_for_event(
     event: EventBase,
     mainline_map: Dict[str, int],
     event_map: Dict[str, EventBase],
-    state_res_store: "synapse.state.StateResolutionStore",
+    state_res_store: StateResolutionStore,
 ) -> int:
     """Get the mainline depths for the given event based on the mainline map
 
@@ -683,7 +714,7 @@ async def _get_event(
     room_id: str,
     event_id: str,
     event_map: Dict[str, EventBase],
-    state_res_store: "synapse.state.StateResolutionStore",
+    state_res_store: StateResolutionStore,
     allow_none: Literal[False] = False,
 ) -> EventBase:
     ...
@@ -694,7 +725,7 @@ async def _get_event(
     room_id: str,
     event_id: str,
     event_map: Dict[str, EventBase],
-    state_res_store: "synapse.state.StateResolutionStore",
+    state_res_store: StateResolutionStore,
     allow_none: Literal[True],
 ) -> Optional[EventBase]:
     ...
@@ -704,7 +735,7 @@ async def _get_event(
     room_id: str,
     event_id: str,
     event_map: Dict[str, EventBase],
-    state_res_store: "synapse.state.StateResolutionStore",
+    state_res_store: StateResolutionStore,
     allow_none: bool = False,
 ) -> Optional[EventBase]:
     """Helper function to look up event in event_map, falling back to looking
diff --git a/synapse/static/client/login/index.html b/synapse/static/client/login/index.html
index 9e6daf38ac..40510889ac 100644
--- a/synapse/static/client/login/index.html
+++ b/synapse/static/client/login/index.html
@@ -3,7 +3,8 @@
 <head>
     <meta http-equiv="Content-Type" content="text/html; charset=utf-8">
     <title> Login </title>
-    <meta name='viewport' content='width=device-width, initial-scale=1, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
     <link rel="stylesheet" href="style.css">
     <script src="js/jquery-3.4.1.min.js"></script>
     <script src="js/login.js"></script>
diff --git a/synapse/static/client/register/index.html b/synapse/static/client/register/index.html
index 140653574d..27bbd76f51 100644
--- a/synapse/static/client/register/index.html
+++ b/synapse/static/client/register/index.html
@@ -2,7 +2,8 @@
 <html>
 <head>
 <title> Registration </title>
-<meta name='viewport' content='width=device-width, initial-scale=1, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'> 
+<meta http-equiv="X-UA-Compatible" content="IE=edge">
+<meta name="viewport" content="width=device-width, initial-scale=1.0">
 <link rel="stylesheet" href="style.css">
 <script src="js/jquery-3.4.1.min.js"></script>
 <script src="https://www.recaptcha.net/recaptcha/api/js/recaptcha_ajax.js"></script>
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index abfc56b061..e30f9c76d4 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -75,6 +75,19 @@ class SQLBaseStore(metaclass=ABCMeta):
             self._attempt_to_invalidate_cache(
                 "get_users_in_room_with_profiles", (room_id,)
             )
+            self._attempt_to_invalidate_cache(
+                "get_number_joined_users_in_room", (room_id,)
+            )
+            self._attempt_to_invalidate_cache("get_local_users_in_room", (room_id,))
+
+            # There's no easy way of invalidating this cache for just the users
+            # that have changed, so we just clear the entire thing.
+            self._attempt_to_invalidate_cache("does_pair_of_users_share_a_room", None)
+
+        for user_id in members_changed:
+            self._attempt_to_invalidate_cache(
+                "get_user_in_room_with_profile", (room_id, user_id)
+            )
 
         # Purge other caches based on room state.
         self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
@@ -87,6 +100,10 @@ class SQLBaseStore(metaclass=ABCMeta):
         cache doesn't exist. Mainly used for invalidating caches on workers,
         where they may not have the cache.
 
+        Note that this function does not invalidate any remote caches, only the
+        local in-memory ones. Any remote invalidation must be performed before
+        calling this.
+
         Args:
             cache_name
             key: Entry to invalidate. If None then invalidates the entire
@@ -103,7 +120,10 @@ class SQLBaseStore(metaclass=ABCMeta):
         if key is None:
             cache.invalidate_all()
         else:
-            cache.invalidate(tuple(key))
+            # Prefer any local-only invalidation method. Invalidating any non-local
+            # cache must be be done before this.
+            invalidate_method = getattr(cache, "invalidate_local", cache.invalidate)
+            invalidate_method(tuple(key))
 
 
 def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index b1e5208c76..555b4e77d2 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -507,25 +507,6 @@ class BackgroundUpdater:
             update_handler
         )
 
-    def register_noop_background_update(self, update_name: str) -> None:
-        """Register a noop handler for a background update.
-
-        This is useful when we previously did a background update, but no
-        longer wish to do the update. In this case the background update should
-        be removed from the schema delta files, but there may still be some
-        users who have the background update queued, so this method should
-        also be called to clear the update.
-
-        Args:
-            update_name: Name of update
-        """
-
-        async def noop_update(progress: JsonDict, batch_size: int) -> int:
-            await self._end_background_update(update_name)
-            return 1
-
-        self.register_background_update_handler(update_name, noop_update)
-
     def register_background_index_update(
         self,
         update_name: str,
diff --git a/synapse/storage/controllers/__init__.py b/synapse/storage/controllers/__init__.py
index 55649719f6..45101cda7a 100644
--- a/synapse/storage/controllers/__init__.py
+++ b/synapse/storage/controllers/__init__.py
@@ -43,4 +43,6 @@ class StorageControllers:
 
         self.persistence = None
         if stores.persist_events:
-            self.persistence = EventsPersistenceStorageController(hs, stores)
+            self.persistence = EventsPersistenceStorageController(
+                hs, stores, self.state
+            )
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index 4caaa81808..dad3731b9b 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -22,6 +22,7 @@ from typing import (
     Any,
     Awaitable,
     Callable,
+    ClassVar,
     Collection,
     Deque,
     Dict,
@@ -33,6 +34,7 @@ from typing import (
     Set,
     Tuple,
     TypeVar,
+    Union,
 )
 
 import attr
@@ -43,12 +45,20 @@ from twisted.internet import defer
 from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
-from synapse.logging import opentracing
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.logging.opentracing import (
+    SynapseTags,
+    active_span,
+    set_tag,
+    start_active_span_follows_from,
+    trace,
+)
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.controllers.state import StateStorageController
 from synapse.storage.databases import Databases
 from synapse.storage.databases.main.events import DeltaState
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
+from synapse.storage.state import StateFilter
 from synapse.types import (
     PersistedEventPosition,
     RoomStreamToken,
@@ -111,9 +121,43 @@ times_pruned_extremities = Counter(
 
 
 @attr.s(auto_attribs=True, slots=True)
-class _EventPersistQueueItem:
+class _PersistEventsTask:
+    """A batch of events to persist."""
+
+    name: ClassVar[str] = "persist_event_batch"  # used for opentracing
+
     events_and_contexts: List[Tuple[EventBase, EventContext]]
     backfilled: bool
+
+    def try_merge(self, task: "_EventPersistQueueTask") -> bool:
+        """Batches events with the same backfilled option together."""
+        if (
+            not isinstance(task, _PersistEventsTask)
+            or self.backfilled != task.backfilled
+        ):
+            return False
+
+        self.events_and_contexts.extend(task.events_and_contexts)
+        return True
+
+
+@attr.s(auto_attribs=True, slots=True)
+class _UpdateCurrentStateTask:
+    """A room whose current state needs recalculating."""
+
+    name: ClassVar[str] = "update_current_state"  # used for opentracing
+
+    def try_merge(self, task: "_EventPersistQueueTask") -> bool:
+        """Deduplicates consecutive recalculations of current state."""
+        return isinstance(task, _UpdateCurrentStateTask)
+
+
+_EventPersistQueueTask = Union[_PersistEventsTask, _UpdateCurrentStateTask]
+
+
+@attr.s(auto_attribs=True, slots=True)
+class _EventPersistQueueItem:
+    task: _EventPersistQueueTask
     deferred: ObservableDeferred
 
     parent_opentracing_span_contexts: List = attr.ib(factory=list)
@@ -127,14 +171,16 @@ _PersistResult = TypeVar("_PersistResult")
 
 
 class _EventPeristenceQueue(Generic[_PersistResult]):
-    """Queues up events so that they can be persisted in bulk with only one
-    concurrent transaction per room.
+    """Queues up tasks so that they can be processed with only one concurrent
+    transaction per room.
+
+    Tasks can be bulk persistence of events or recalculation of a room's current state.
     """
 
     def __init__(
         self,
         per_item_callback: Callable[
-            [List[Tuple[EventBase, EventContext]], bool],
+            [str, _EventPersistQueueTask],
             Awaitable[_PersistResult],
         ],
     ):
@@ -150,18 +196,17 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
     async def add_to_queue(
         self,
         room_id: str,
-        events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
-        backfilled: bool,
+        task: _EventPersistQueueTask,
     ) -> _PersistResult:
-        """Add events to the queue, with the given persist_event options.
+        """Add a task to the queue.
 
-        If we are not already processing events in this room, starts off a background
+        If we are not already processing tasks in this room, starts off a background
         process to to so, calling the per_item_callback for each item.
 
         Args:
             room_id (str):
-            events_and_contexts (list[(EventBase, EventContext)]):
-            backfilled (bool):
+            task (_EventPersistQueueTask): A _PersistEventsTask or
+                _UpdateCurrentStateTask to process.
 
         Returns:
             the result returned by the `_per_item_callback` passed to
@@ -169,28 +214,22 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
         """
         queue = self._event_persist_queues.setdefault(room_id, deque())
 
-        # if the last item in the queue has the same `backfilled` setting,
-        # we can just add these new events to that item.
-        if queue and queue[-1].backfilled == backfilled:
+        if queue and queue[-1].task.try_merge(task):
+            # the new task has been merged into the last task in the queue
             end_item = queue[-1]
         else:
-            # need to make a new queue item
             deferred: ObservableDeferred[_PersistResult] = ObservableDeferred(
                 defer.Deferred(), consumeErrors=True
             )
 
             end_item = _EventPersistQueueItem(
-                events_and_contexts=[],
-                backfilled=backfilled,
+                task=task,
                 deferred=deferred,
             )
             queue.append(end_item)
 
-        # add our events to the queue item
-        end_item.events_and_contexts.extend(events_and_contexts)
-
         # also add our active opentracing span to the item so that we get a link back
-        span = opentracing.active_span()
+        span = active_span()
         if span:
             end_item.parent_opentracing_span_contexts.append(span.context)
 
@@ -201,8 +240,8 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
         res = await make_deferred_yieldable(end_item.deferred.observe())
 
         # add another opentracing span which links to the persist trace.
-        with opentracing.start_active_span_follows_from(
-            "persist_event_batch_complete", (end_item.opentracing_span_context,)
+        with start_active_span_follows_from(
+            f"{task.name}_complete", (end_item.opentracing_span_context,)
         ):
             pass
 
@@ -233,17 +272,15 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
                 queue = self._get_drainining_queue(room_id)
                 for item in queue:
                     try:
-                        with opentracing.start_active_span_follows_from(
-                            "persist_event_batch",
+                        with start_active_span_follows_from(
+                            item.task.name,
                             item.parent_opentracing_span_contexts,
                             inherit_force_tracing=True,
                         ) as scope:
                             if scope:
                                 item.opentracing_span_context = scope.span.context
 
-                            ret = await self._per_item_callback(
-                                item.events_and_contexts, item.backfilled
-                            )
+                            ret = await self._per_item_callback(room_id, item.task)
                     except Exception:
                         with PreserveLoggingContext():
                             item.deferred.errback()
@@ -279,7 +316,12 @@ class EventsPersistenceStorageController:
     current state and forward extremity changes.
     """
 
-    def __init__(self, hs: "HomeServer", stores: Databases):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        stores: Databases,
+        state_controller: StateStorageController,
+    ):
         # We ultimately want to split out the state store from the main store,
         # so we use separate variables here even though they point to the same
         # store for now.
@@ -292,10 +334,34 @@ class EventsPersistenceStorageController:
         self._clock = hs.get_clock()
         self._instance_name = hs.get_instance_name()
         self.is_mine_id = hs.is_mine_id
-        self._event_persist_queue = _EventPeristenceQueue(self._persist_event_batch)
+        self._event_persist_queue = _EventPeristenceQueue(
+            self._process_event_persist_queue_task
+        )
         self._state_resolution_handler = hs.get_state_resolution_handler()
+        self._state_controller = state_controller
 
-    @opentracing.trace
+    async def _process_event_persist_queue_task(
+        self,
+        room_id: str,
+        task: _EventPersistQueueTask,
+    ) -> Dict[str, str]:
+        """Callback for the _event_persist_queue
+
+        Returns:
+            A dictionary of event ID to event ID we didn't persist as we already
+            had another event persisted with the same TXN ID.
+        """
+        if isinstance(task, _PersistEventsTask):
+            return await self._persist_event_batch(room_id, task)
+        elif isinstance(task, _UpdateCurrentStateTask):
+            await self._update_current_state(room_id, task)
+            return {}
+        else:
+            raise AssertionError(
+                f"Found an unexpected task type in event persistence queue: {task}"
+            )
+
+    @trace
     async def persist_events(
         self,
         events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
@@ -315,17 +381,34 @@ class EventsPersistenceStorageController:
             if they were deduplicated due to an event already existing that
             matched the transaction ID; the existing event is returned in such
             a case.
+
+        Raises:
+            PartialStateConflictError: if attempting to persist a partial state event in
+                a room that has been un-partial stated.
         """
+        event_ids: List[str] = []
         partitioned: Dict[str, List[Tuple[EventBase, EventContext]]] = {}
         for event, ctx in events_and_contexts:
             partitioned.setdefault(event.room_id, []).append((event, ctx))
+            event_ids.append(event.event_id)
+
+        set_tag(
+            SynapseTags.FUNC_ARG_PREFIX + "event_ids",
+            str(event_ids),
+        )
+        set_tag(
+            SynapseTags.FUNC_ARG_PREFIX + "event_ids.length",
+            str(len(event_ids)),
+        )
+        set_tag(SynapseTags.FUNC_ARG_PREFIX + "backfilled", str(backfilled))
 
         async def enqueue(
             item: Tuple[str, List[Tuple[EventBase, EventContext]]]
         ) -> Dict[str, str]:
             room_id, evs_ctxs = item
             return await self._event_persist_queue.add_to_queue(
-                room_id, evs_ctxs, backfilled=backfilled
+                room_id,
+                _PersistEventsTask(events_and_contexts=evs_ctxs, backfilled=backfilled),
             )
 
         ret_vals = await yieldable_gather_results(enqueue, partitioned.items())
@@ -353,7 +436,7 @@ class EventsPersistenceStorageController:
             self.main_store.get_room_max_token(),
         )
 
-    @opentracing.trace
+    @trace
     async def persist_event(
         self, event: EventBase, context: EventContext, backfilled: bool = False
     ) -> Tuple[EventBase, PersistedEventPosition, RoomStreamToken]:
@@ -363,12 +446,19 @@ class EventsPersistenceStorageController:
             latest persisted event. The returned event may not match the given
             event if it was deduplicated due to an existing event matching the
             transaction ID.
+
+        Raises:
+            PartialStateConflictError: if attempting to persist a partial state event in
+                a room that has been un-partial stated.
         """
         # add_to_queue returns a map from event ID to existing event ID if the
         # event was deduplicated. (The dict may also include other entries if
         # the event was persisted in a batch with other events.)
         replaced_events = await self._event_persist_queue.add_to_queue(
-            event.room_id, [(event, context)], backfilled=backfilled
+            event.room_id,
+            _PersistEventsTask(
+                events_and_contexts=[(event, context)], backfilled=backfilled
+            ),
         )
         replaced_event = replaced_events.get(event.event_id)
         if replaced_event:
@@ -383,17 +473,22 @@ class EventsPersistenceStorageController:
 
     async def update_current_state(self, room_id: str) -> None:
         """Recalculate the current state for a room, and persist it"""
+        await self._event_persist_queue.add_to_queue(
+            room_id,
+            _UpdateCurrentStateTask(),
+        )
+
+    async def _update_current_state(
+        self, room_id: str, _task: _UpdateCurrentStateTask
+    ) -> None:
+        """Callback for the _event_persist_queue
+
+        Recalculates the current state for a room, and persists it.
+        """
         state = await self._calculate_current_state(room_id)
         delta = await self._calculate_state_delta(room_id, state)
 
-        # TODO(faster_joins): get a real stream ordering, to make this work correctly
-        #    across workers.
-        #
-        # TODO(faster_joins): this can race against event persistence, in which case we
-        #    will end up with incorrect state. Perhaps we should make this a job we
-        #    farm out to the event persister, somehow.
-        stream_id = self.main_store.get_room_max_stream_ordering()
-        await self.persist_events_store.update_current_state(room_id, delta, stream_id)
+        await self.persist_events_store.update_current_state(room_id, delta)
 
     async def _calculate_current_state(self, room_id: str) -> StateMap[str]:
         """Calculate the current state of a room, based on the forward extremities
@@ -435,12 +530,10 @@ class EventsPersistenceStorageController:
             state_res_store=StateResolutionStore(self.main_store),
         )
 
-        return res.state
+        return await res.get_state(self._state_controller, StateFilter.all())
 
     async def _persist_event_batch(
-        self,
-        events_and_contexts: List[Tuple[EventBase, EventContext]],
-        backfilled: bool = False,
+        self, _room_id: str, task: _PersistEventsTask
     ) -> Dict[str, str]:
         """Callback for the _event_persist_queue
 
@@ -450,7 +543,14 @@ class EventsPersistenceStorageController:
         Returns:
             A dictionary of event ID to event ID we didn't persist as we already
             had another event persisted with the same TXN ID.
+
+        Raises:
+            PartialStateConflictError: if attempting to persist a partial state event in
+                a room that has been un-partial stated.
         """
+        events_and_contexts = task.events_and_contexts
+        backfilled = task.backfilled
+
         replaced_events: Dict[str, str] = {}
         if not events_and_contexts:
             return replaced_events
@@ -866,7 +966,8 @@ class EventsPersistenceStorageController:
                 events_context,
             )
 
-        return res.state, None, new_latest_event_ids
+        full_state = await res.get_state(self._state_controller)
+        return full_state, None, new_latest_event_ids
 
     async def _prune_extremities(
         self,
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 3b4cdb67eb..ba5380ce3e 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -23,12 +23,13 @@ from typing import (
     List,
     Mapping,
     Optional,
-    Set,
     Tuple,
 )
 
 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
+from synapse.logging.opentracing import tag_args, trace
+from synapse.storage.roommember import ProfileInfo
 from synapse.storage.state import StateFilter
 from synapse.storage.util.partial_state_events_tracker import (
     PartialCurrentStateTracker,
@@ -82,13 +83,15 @@ class StateStorageController:
         return state_group_delta.prev_group, state_group_delta.delta_ids
 
     async def get_state_groups_ids(
-        self, _room_id: str, event_ids: Collection[str]
+        self, _room_id: str, event_ids: Collection[str], await_full_state: bool = True
     ) -> Dict[int, MutableStateMap[str]]:
         """Get the event IDs of all the state for the state groups for the given events
 
         Args:
             _room_id: id of the room for these events
             event_ids: ids of the events
+            await_full_state: if `True`, will block if we do not yet have complete
+               state at these events.
 
         Returns:
             dict of state_group_id -> (dict of (type, state_key) -> event id)
@@ -100,7 +103,9 @@ class StateStorageController:
         if not event_ids:
             return {}
 
-        event_to_groups = await self.get_state_group_for_events(event_ids)
+        event_to_groups = await self.get_state_group_for_events(
+            event_ids, await_full_state=await_full_state
+        )
 
         groups = set(event_to_groups.values())
         group_to_state = await self.stores.state._get_state_for_groups(groups)
@@ -175,6 +180,7 @@ class StateStorageController:
 
         return self.stores.state._get_state_groups_from_groups(groups, state_filter)
 
+    @trace
     async def get_state_for_events(
         self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
     ) -> Dict[str, StateMap[EventBase]]:
@@ -221,10 +227,13 @@ class StateStorageController:
 
         return {event: event_to_state[event] for event in event_ids}
 
+    @trace
+    @tag_args
     async def get_state_ids_for_events(
         self,
         event_ids: Collection[str],
         state_filter: Optional[StateFilter] = None,
+        await_full_state: bool = True,
     ) -> Dict[str, StateMap[str]]:
         """
         Get the state dicts corresponding to a list of events, containing the event_ids
@@ -233,6 +242,9 @@ class StateStorageController:
         Args:
             event_ids: events whose state should be returned
             state_filter: The state filter used to fetch state from the database.
+            await_full_state: if `True`, will block if we do not yet have complete state
+                at these events and `state_filter` is not satisfied by partial state.
+                Defaults to `True`.
 
         Returns:
             A dict from event_id -> (type, state_key) -> event_id
@@ -241,8 +253,12 @@ class StateStorageController:
             RuntimeError if we don't have a state group for one or more of the events
                 (ie they are outliers or unknown)
         """
-        await_full_state = True
-        if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
+        if (
+            await_full_state
+            and state_filter
+            and not state_filter.must_await_full_state(self._is_mine_id)
+        ):
+            # Full state is not required if the state filter is restrictive enough.
             await_full_state = False
 
         event_to_groups = await self.get_state_group_for_events(
@@ -283,8 +299,12 @@ class StateStorageController:
         )
         return state_map[event_id]
 
+    @trace
     async def get_state_ids_for_event(
-        self, event_id: str, state_filter: Optional[StateFilter] = None
+        self,
+        event_id: str,
+        state_filter: Optional[StateFilter] = None,
+        await_full_state: bool = True,
     ) -> StateMap[str]:
         """
         Get the state dict corresponding to a particular event
@@ -292,6 +312,9 @@ class StateStorageController:
         Args:
             event_id: event whose state should be returned
             state_filter: The state filter used to fetch state from the database.
+            await_full_state: if `True`, will block if we do not yet have complete state
+                at the event and `state_filter` is not satisfied by partial state.
+                Defaults to `True`.
 
         Returns:
             A dict from (type, state_key) -> state_event_id
@@ -301,7 +324,9 @@ class StateStorageController:
                 outlier or is unknown)
         """
         state_map = await self.get_state_ids_for_events(
-            [event_id], state_filter or StateFilter.all()
+            [event_id],
+            state_filter or StateFilter.all(),
+            await_full_state=await_full_state,
         )
         return state_map[event_id]
 
@@ -323,6 +348,8 @@ class StateStorageController:
             groups, state_filter or StateFilter.all()
         )
 
+    @trace
+    @tag_args
     async def get_state_group_for_events(
         self,
         event_ids: Collection[str],
@@ -334,6 +361,10 @@ class StateStorageController:
             event_ids: events to get state groups for
             await_full_state: if true, will block if we do not yet have complete
                state at these events.
+
+        Raises:
+            RuntimeError if we don't have a state group for one or more of the events
+               (ie. they are outliers or unknown)
         """
         if await_full_state:
             await self._partial_state_events_tracker.await_full_state(event_ids)
@@ -346,7 +377,7 @@ class StateStorageController:
         room_id: str,
         prev_group: Optional[int],
         delta_ids: Optional[StateMap[str]],
-        current_state_ids: StateMap[str],
+        current_state_ids: Optional[StateMap[str]],
     ) -> int:
         """Store a new set of state, returning a newly assigned state group.
 
@@ -452,11 +483,15 @@ class StateStorageController:
                  up to date.
         """
         # FIXME(faster_joins): what do we do here?
+        #   https://github.com/matrix-org/synapse/issues/12814
+        #   https://github.com/matrix-org/synapse/issues/12815
+        #   https://github.com/matrix-org/synapse/issues/13008
 
         return await self.stores.main.get_partial_current_state_deltas(
             prev_stream_id, max_stream_id
         )
 
+    @trace
     async def get_current_state(
         self, room_id: str, state_filter: Optional[StateFilter] = None
     ) -> StateMap[EventBase]:
@@ -484,9 +519,21 @@ class StateStorageController:
         )
         return state_map.get(key)
 
-    async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
+    async def get_current_hosts_in_room(self, room_id: str) -> List[str]:
         """Get current hosts in room based on current state."""
 
         await self._partial_state_room_tracker.await_full_state(room_id)
 
         return await self.stores.main.get_current_hosts_in_room(room_id)
+
+    async def get_users_in_room_with_profiles(
+        self, room_id: str
+    ) -> Dict[str, ProfileInfo]:
+        """
+        Get the current users in the room with their profiles.
+        If the room is currently partial-stated, this will block until the room has
+        full state.
+        """
+        await self._partial_state_room_tracker.await_full_state(room_id)
+
+        return await self.stores.main.get_users_in_room_with_profiles(room_id)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index a78d68a9d7..b394a6658b 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -23,6 +23,7 @@ from time import monotonic as monotonic_time
 from typing import (
     TYPE_CHECKING,
     Any,
+    Awaitable,
     Callable,
     Collection,
     Dict,
@@ -38,7 +39,7 @@ from typing import (
 )
 
 import attr
-from prometheus_client import Histogram
+from prometheus_client import Counter, Histogram
 from typing_extensions import Concatenate, Literal, ParamSpec
 
 from twisted.enterprise import adbapi
@@ -75,7 +76,8 @@ perf_logger = logging.getLogger("synapse.storage.TIME")
 sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec")
 
 sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"])
-sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"])
+sql_txn_count = Counter("synapse_storage_transaction_time_count", "sec", ["desc"])
+sql_txn_duration = Counter("synapse_storage_transaction_time_sum", "sec", ["desc"])
 
 
 # Unique indexes which have been added in background updates. Maps from table name
@@ -92,6 +94,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
     "event_search": "event_search_event_id_idx",
     "local_media_repository_thumbnails": "local_media_repository_thumbnails_method_idx",
     "remote_media_cache_thumbnails": "remote_media_repository_thumbnails_method_idx",
+    "event_push_summary": "event_push_summary_unique_index",
 }
 
 
@@ -167,6 +170,7 @@ class LoggingDatabaseConnection:
         *,
         txn_name: Optional[str] = None,
         after_callbacks: Optional[List["_CallbackListEntry"]] = None,
+        async_after_callbacks: Optional[List["_AsyncCallbackListEntry"]] = None,
         exception_callbacks: Optional[List["_CallbackListEntry"]] = None,
     ) -> "LoggingTransaction":
         if not txn_name:
@@ -177,6 +181,7 @@ class LoggingDatabaseConnection:
             name=txn_name,
             database_engine=self.engine,
             after_callbacks=after_callbacks,
+            async_after_callbacks=async_after_callbacks,
             exception_callbacks=exception_callbacks,
         )
 
@@ -208,6 +213,9 @@ class LoggingDatabaseConnection:
 
 # The type of entry which goes on our after_callbacks and exception_callbacks lists.
 _CallbackListEntry = Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]]
+_AsyncCallbackListEntry = Tuple[
+    Callable[..., Awaitable], Tuple[object, ...], Dict[str, object]
+]
 
 P = ParamSpec("P")
 R = TypeVar("R")
@@ -226,6 +234,10 @@ class LoggingTransaction:
             that have been added by `call_after` which should be run on
             successful completion of the transaction. None indicates that no
             callbacks should be allowed to be scheduled to run.
+        async_after_callbacks: A list that asynchronous callbacks will be appended
+            to by `async_call_after` which should run, before after_callbacks, on
+            successful completion of the transaction. None indicates that no
+            callbacks should be allowed to be scheduled to run.
         exception_callbacks: A list that callbacks will be appended
             to that have been added by `call_on_exception` which should be run
             if transaction ends with an error. None indicates that no callbacks
@@ -237,6 +249,7 @@ class LoggingTransaction:
         "name",
         "database_engine",
         "after_callbacks",
+        "async_after_callbacks",
         "exception_callbacks",
     ]
 
@@ -246,12 +259,14 @@ class LoggingTransaction:
         name: str,
         database_engine: BaseDatabaseEngine,
         after_callbacks: Optional[List[_CallbackListEntry]] = None,
+        async_after_callbacks: Optional[List[_AsyncCallbackListEntry]] = None,
         exception_callbacks: Optional[List[_CallbackListEntry]] = None,
     ):
         self.txn = txn
         self.name = name
         self.database_engine = database_engine
         self.after_callbacks = after_callbacks
+        self.async_after_callbacks = async_after_callbacks
         self.exception_callbacks = exception_callbacks
 
     def call_after(
@@ -276,6 +291,28 @@ class LoggingTransaction:
         # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
         self.after_callbacks.append((callback, args, kwargs))  # type: ignore[arg-type]
 
+    def async_call_after(
+        self, callback: Callable[P, Awaitable], *args: P.args, **kwargs: P.kwargs
+    ) -> None:
+        """Call the given asynchronous callback on the main twisted thread after
+        the transaction has finished (but before those added in `call_after`).
+
+        Mostly used to invalidate remote caches after transactions.
+
+        Note that transactions may be retried a few times if they encounter database
+        errors such as serialization failures. Callbacks given to `async_call_after`
+        will accumulate across transaction attempts and will _all_ be called once a
+        transaction attempt succeeds, regardless of whether previous transaction
+        attempts failed. Otherwise, if all transaction attempts fail, all
+        `call_on_exception` callbacks will be run instead.
+        """
+        # if self.async_after_callbacks is None, that means that whatever constructed the
+        # LoggingTransaction isn't expecting there to be any callbacks; assert that
+        # is not the case.
+        assert self.async_after_callbacks is not None
+        # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
+        self.async_after_callbacks.append((callback, args, kwargs))  # type: ignore[arg-type]
+
     def call_on_exception(
         self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs
     ) -> None:
@@ -365,10 +402,11 @@ class LoggingTransaction:
         *args: P.args,
         **kwargs: P.kwargs,
     ) -> R:
-        sql = self._make_sql_one_line(sql)
+        # Generate a one-line version of the SQL to better log it.
+        one_line_sql = self._make_sql_one_line(sql)
 
         # TODO(paul): Maybe use 'info' and 'debug' for values?
-        sql_logger.debug("[SQL] {%s} %s", self.name, sql)
+        sql_logger.debug("[SQL] {%s} %s", self.name, one_line_sql)
 
         sql = self.database_engine.convert_param_style(sql)
         if args:
@@ -388,7 +426,7 @@ class LoggingTransaction:
                 "db.query",
                 tags={
                     opentracing.tags.DATABASE_TYPE: "sql",
-                    opentracing.tags.DATABASE_STATEMENT: sql,
+                    opentracing.tags.DATABASE_STATEMENT: one_line_sql,
                 },
             ):
                 return func(sql, *args, **kwargs)
@@ -572,6 +610,7 @@ class DatabasePool:
         conn: LoggingDatabaseConnection,
         desc: str,
         after_callbacks: List[_CallbackListEntry],
+        async_after_callbacks: List[_AsyncCallbackListEntry],
         exception_callbacks: List[_CallbackListEntry],
         func: Callable[Concatenate[LoggingTransaction, P], R],
         *args: P.args,
@@ -595,6 +634,7 @@ class DatabasePool:
             conn
             desc
             after_callbacks
+            async_after_callbacks
             exception_callbacks
             func
             *args
@@ -657,6 +697,7 @@ class DatabasePool:
                 cursor = conn.cursor(
                     txn_name=name,
                     after_callbacks=after_callbacks,
+                    async_after_callbacks=async_after_callbacks,
                     exception_callbacks=exception_callbacks,
                 )
                 try:
@@ -755,7 +796,8 @@ class DatabasePool:
 
             self._current_txn_total_time += duration
             self._txn_perf_counters.update(desc, duration)
-            sql_txn_timer.labels(desc).observe(duration)
+            sql_txn_count.labels(desc).inc(1)
+            sql_txn_duration.labels(desc).inc(duration)
 
     async def runInteraction(
         self,
@@ -796,6 +838,7 @@ class DatabasePool:
 
         async def _runInteraction() -> R:
             after_callbacks: List[_CallbackListEntry] = []
+            async_after_callbacks: List[_AsyncCallbackListEntry] = []
             exception_callbacks: List[_CallbackListEntry] = []
 
             if not current_context():
@@ -807,6 +850,7 @@ class DatabasePool:
                         self.new_transaction,
                         desc,
                         after_callbacks,
+                        async_after_callbacks,
                         exception_callbacks,
                         func,
                         *args,
@@ -815,13 +859,17 @@ class DatabasePool:
                         **kwargs,
                     )
 
+                # We order these assuming that async functions call out to external
+                # systems (e.g. to invalidate a cache) and the sync functions make these
+                # changes on any local in-memory caches/similar, and thus must be second.
+                for async_callback, async_args, async_kwargs in async_after_callbacks:
+                    await async_callback(*async_args, **async_kwargs)
                 for after_callback, after_args, after_kwargs in after_callbacks:
                     after_callback(*after_args, **after_kwargs)
-
                 return cast(R, result)
             except Exception:
-                for after_callback, after_args, after_kwargs in exception_callbacks:
-                    after_callback(*after_args, **after_kwargs)
+                for exception_callback, after_args, after_kwargs in exception_callbacks:
+                    exception_callback(*after_args, **after_kwargs)
                 raise
 
         # To handle cancellation, we ensure that `after_callback`s and
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 11d9d16c19..4dccbb732a 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -24,9 +24,9 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.databases.main.stats import UserSortOrder
-from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.engines import BaseDatabaseEngine
 from synapse.storage.types import Cursor
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.types import JsonDict, get_domain_from_id
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
@@ -45,7 +45,6 @@ from .event_push_actions import EventPushActionsStore
 from .events_bg_updates import EventsBackgroundUpdatesStore
 from .events_forward_extremities import EventForwardExtremitiesStore
 from .filtering import FilteringStore
-from .group_server import GroupServerStore
 from .keys import KeyStore
 from .lock import LockStore
 from .media_repository import MediaRepositoryStore
@@ -88,7 +87,6 @@ class DataStore(
     RoomStore,
     RoomBatchStore,
     RegistrationStore,
-    StreamWorkerStore,
     ProfileStore,
     PresenceStore,
     TransactionWorkerStore,
@@ -105,19 +103,20 @@ class DataStore(
     PusherStore,
     PushRuleStore,
     ApplicationServiceTransactionStore,
+    EventPushActionsStore,
+    ServerMetricsStore,
     ReceiptsStore,
     EndToEndKeyStore,
     EndToEndRoomKeyStore,
     SearchStore,
     TagsStore,
     AccountDataStore,
-    EventPushActionsStore,
+    StreamWorkerStore,
     OpenIdStore,
     ClientIpWorkerStore,
     DeviceStore,
     DeviceInboxStore,
     UserDirectoryStore,
-    GroupServerStore,
     UserErasureStore,
     MonthlyActiveUsersWorkerStore,
     StatsStore,
@@ -126,7 +125,6 @@ class DataStore(
     UIAuthStore,
     EventForwardExtremitiesStore,
     CacheInvalidationWorkerStore,
-    ServerMetricsStore,
     LockStore,
     SessionStore,
 ):
@@ -151,31 +149,6 @@ class DataStore(
             ],
         )
 
-        self._cache_id_gen: Optional[MultiWriterIdGenerator]
-        if isinstance(self.database_engine, PostgresEngine):
-            # We set the `writers` to an empty list here as we don't care about
-            # missing updates over restarts, as we'll not have anything in our
-            # caches to invalidate. (This reduces the amount of writes to the DB
-            # that happen).
-            self._cache_id_gen = MultiWriterIdGenerator(
-                db_conn,
-                database,
-                stream_name="caches",
-                instance_name=hs.get_instance_name(),
-                tables=[
-                    (
-                        "cache_invalidation_stream_by_instance",
-                        "instance_name",
-                        "stream_id",
-                    )
-                ],
-                sequence_name="cache_invalidation_stream_seq",
-                writers=[],
-            )
-
-        else:
-            self._cache_id_gen = None
-
         super().__init__(database, db_conn, hs)
 
         events_max = self._stream_id_gen.get_current_token()
@@ -197,6 +170,7 @@ class DataStore(
         self._min_stream_order_on_start = self.get_room_min_stream_ordering()
 
     def get_device_stream_token(self) -> int:
+        # TODO: shouldn't this be moved to `DeviceWorkerStore`?
         return self._device_list_id_gen.get_current_token()
 
     async def get_users(self) -> List[JsonDict]:
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 9af9f4f18e..c38b8a9e5a 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -650,9 +650,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             txn, self.get_account_data_for_room, (user_id,)
         )
         self._invalidate_cache_and_stream(txn, self.get_push_rules_for_user, (user_id,))
-        self._invalidate_cache_and_stream(
-            txn, self.get_push_rules_enabled_for_user, (user_id,)
-        )
         # This user might be contained in the ignored_by cache for other users,
         # so we have to invalidate it all.
         self._invalidate_all_cache_and_stream(txn, self.ignored_by)
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index e284454b66..64b70a7b28 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -371,52 +371,30 @@ class ApplicationServiceTransactionWorkerStore(
             device_list_summary=DeviceListUpdates(),
         )
 
-    async def set_appservice_last_pos(self, pos: int) -> None:
-        def set_appservice_last_pos_txn(txn: LoggingTransaction) -> None:
-            txn.execute(
-                "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
-            )
+    async def get_appservice_last_pos(self) -> int:
+        """
+        Get the last stream ordering position for the appservice process.
+        """
 
-        await self.db_pool.runInteraction(
-            "set_appservice_last_pos", set_appservice_last_pos_txn
+        return await self.db_pool.simple_select_one_onecol(
+            table="appservice_stream_position",
+            retcol="stream_ordering",
+            keyvalues={},
+            desc="get_appservice_last_pos",
         )
 
-    async def get_new_events_for_appservice(
-        self, current_id: int, limit: int
-    ) -> Tuple[int, List[EventBase]]:
-        """Get all new events for an appservice"""
-
-        def get_new_events_for_appservice_txn(
-            txn: LoggingTransaction,
-        ) -> Tuple[int, List[str]]:
-            sql = (
-                "SELECT e.stream_ordering, e.event_id"
-                " FROM events AS e"
-                " WHERE"
-                " (SELECT stream_ordering FROM appservice_stream_position)"
-                "     < e.stream_ordering"
-                " AND e.stream_ordering <= ?"
-                " ORDER BY e.stream_ordering ASC"
-                " LIMIT ?"
-            )
-
-            txn.execute(sql, (current_id, limit))
-            rows = txn.fetchall()
-
-            upper_bound = current_id
-            if len(rows) == limit:
-                upper_bound = rows[-1][0]
-
-            return upper_bound, [row[1] for row in rows]
+    async def set_appservice_last_pos(self, pos: int) -> None:
+        """
+        Set the last stream ordering position for the appservice process.
+        """
 
-        upper_bound, event_ids = await self.db_pool.runInteraction(
-            "get_new_events_for_appservice", get_new_events_for_appservice_txn
+        await self.db_pool.simple_update_one(
+            table="appservice_stream_position",
+            keyvalues={},
+            updatevalues={"stream_ordering": pos},
+            desc="set_appservice_last_pos",
         )
 
-        events = await self.get_events_as_list(event_ids, get_prev_content=True)
-
-        return upper_bound, events
-
     async def get_type_stream_id_for_appservice(
         self, service: ApplicationService, type: str
     ) -> int:
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 1653a6a9b6..12e9a42382 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -32,6 +32,7 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.util.caches.descriptors import _CachedFunction
 from synapse.util.iterutils import batch_iter
 
@@ -65,6 +66,31 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             psql_only=True,  # The table is only on postgres DBs.
         )
 
+        self._cache_id_gen: Optional[MultiWriterIdGenerator]
+        if isinstance(self.database_engine, PostgresEngine):
+            # We set the `writers` to an empty list here as we don't care about
+            # missing updates over restarts, as we'll not have anything in our
+            # caches to invalidate. (This reduces the amount of writes to the DB
+            # that happen).
+            self._cache_id_gen = MultiWriterIdGenerator(
+                db_conn,
+                database,
+                stream_name="caches",
+                instance_name=hs.get_instance_name(),
+                tables=[
+                    (
+                        "cache_invalidation_stream_by_instance",
+                        "instance_name",
+                        "stream_id",
+                    )
+                ],
+                sequence_name="cache_invalidation_stream_seq",
+                writers=[],
+            )
+
+        else:
+            self._cache_id_gen = None
+
     async def get_all_updated_caches(
         self, instance_name: str, last_id: int, current_id: int, limit: int
     ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
@@ -193,7 +219,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         relates_to: Optional[str],
         backfilled: bool,
     ) -> None:
-        self._invalidate_get_event_cache(event_id)
+        # This invalidates any local in-memory cached event objects, the original
+        # process triggering the invalidation is responsible for clearing any external
+        # cached objects.
+        self._invalidate_local_get_event_cache(event_id)
         self.have_seen_event.invalidate((room_id, event_id))
 
         self.get_latest_event_ids_in_room.invalidate((room_id,))
@@ -208,7 +237,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
 
         if redacts:
-            self._invalidate_get_event_cache(redacts)
+            self._invalidate_local_get_event_cache(redacts)
             # Caches which might leak edits must be invalidated for the event being
             # redacted.
             self.get_relations_for_event.invalidate((redacts,))
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index fd3fc298b3..58177ecec1 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -194,7 +194,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
             # changed its content in the database. We can't call
             # self._invalidate_cache_and_stream because self.get_event_cache isn't of the
             # right type.
-            txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
+            self.invalidate_get_event_cache_after_txn(txn, event.event_id)
             # Send that invalidation to replication so that other workers also invalidate
             # the event cache.
             self._send_invalidation_to_replication(
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 599b418383..73c95ffb6f 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -436,7 +436,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             (user_id, device_id), None
         )
 
-        set_tag("last_deleted_stream_id", last_deleted_stream_id)
+        set_tag("last_deleted_stream_id", str(last_deleted_stream_id))
 
         if last_deleted_stream_id:
             has_changed = self._device_inbox_stream_cache.has_entity_changed(
@@ -834,8 +834,6 @@ class DeviceInboxWorkerStore(SQLBaseStore):
 
 class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
-    REMOVE_DELETED_DEVICES = "remove_deleted_devices_from_device_inbox"
-    REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox"
     REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox"
 
     def __init__(
@@ -857,15 +855,6 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
             self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
         )
 
-        # Used to be a background update that deletes all device_inboxes for deleted
-        # devices.
-        self.db_pool.updates.register_noop_background_update(
-            self.REMOVE_DELETED_DEVICES
-        )
-        # Used to be a background update that deletes all device_inboxes for hidden
-        # devices.
-        self.db_pool.updates.register_noop_background_update(self.REMOVE_HIDDEN_DEVICES)
-
         self.db_pool.updates.register_background_update_handler(
             self.REMOVE_DEAD_DEVICES_FROM_INBOX,
             self._remove_dead_devices_from_device_inbox,
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 71e7863dd8..ca0fe8c4be 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -28,6 +28,8 @@ from typing import (
     cast,
 )
 
+from typing_extensions import Literal
+
 from synapse.api.constants import EduTypes
 from synapse.api.errors import Codes, StoreError
 from synapse.logging.opentracing import (
@@ -44,6 +46,8 @@ from synapse.storage.database import (
     LoggingTransaction,
     make_tuple_comparison_clause,
 )
+from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
+from synapse.storage.types import Cursor
 from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
 from synapse.util import json_decoder, json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
@@ -65,7 +69,7 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
 BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
 
 
-class DeviceWorkerStore(SQLBaseStore):
+class DeviceWorkerStore(EndToEndKeyWorkerStore):
     def __init__(
         self,
         database: DatabasePool,
@@ -74,7 +78,9 @@ class DeviceWorkerStore(SQLBaseStore):
     ):
         super().__init__(database, db_conn, hs)
 
-        device_list_max = self._device_list_id_gen.get_current_token()
+        # Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a
+        # StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker).
+        device_list_max = self._device_list_id_gen.get_current_token()  # type: ignore[attr-defined]
         device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict(
             db_conn,
             "device_lists_stream",
@@ -339,8 +345,9 @@ class DeviceWorkerStore(SQLBaseStore):
         # following this stream later.
         last_processed_stream_id = from_stream_id
 
-        query_map = {}
-        cross_signing_keys_by_user = {}
+        # A map of (user ID, device ID) to (stream ID, context).
+        query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]] = {}
+        cross_signing_keys_by_user: Dict[str, Dict[str, object]] = {}
         for user_id, device_id, update_stream_id, update_context in updates:
             # Calculate the remaining length budget.
             # Note that, for now, each entry in `cross_signing_keys_by_user`
@@ -596,7 +603,7 @@ class DeviceWorkerStore(SQLBaseStore):
             txn=txn,
             table="device_lists_outbound_last_success",
             key_names=("destination", "user_id"),
-            key_values=((destination, user_id) for user_id, _ in rows),
+            key_values=[(destination, user_id) for user_id, _ in rows],
             value_names=("stream_id",),
             value_values=((stream_id,) for _, stream_id in rows),
         )
@@ -621,7 +628,9 @@ class DeviceWorkerStore(SQLBaseStore):
             The new stream ID.
         """
 
-        async with self._device_list_id_gen.get_next() as stream_id:
+        # TODO: this looks like it's _writing_. Should this be on DeviceStore rather
+        #  than DeviceWorkerStore?
+        async with self._device_list_id_gen.get_next() as stream_id:  # type: ignore[attr-defined]
             await self.db_pool.runInteraction(
                 "add_user_sig_change_to_streams",
                 self._add_user_signature_change_txn,
@@ -660,7 +669,7 @@ class DeviceWorkerStore(SQLBaseStore):
 
     @trace
     async def get_user_devices_from_cache(
-        self, query_list: List[Tuple[str, str]]
+        self, query_list: List[Tuple[str, Optional[str]]]
     ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
         """Get the devices (and keys if any) for remote users from the cache.
 
@@ -686,7 +695,7 @@ class DeviceWorkerStore(SQLBaseStore):
         } - users_needing_resync
         user_ids_not_in_cache = user_ids - user_ids_in_cache
 
-        results = {}
+        results: Dict[str, Dict[str, JsonDict]] = {}
         for user_id, device_id in query_list:
             if user_id not in user_ids_in_cache:
                 continue
@@ -697,8 +706,8 @@ class DeviceWorkerStore(SQLBaseStore):
             else:
                 results[user_id] = await self.get_cached_devices_for_user(user_id)
 
-        set_tag("in_cache", results)
-        set_tag("not_in_cache", user_ids_not_in_cache)
+        set_tag("in_cache", str(results))
+        set_tag("not_in_cache", str(user_ids_not_in_cache))
 
         return user_ids_not_in_cache, results
 
@@ -727,7 +736,7 @@ class DeviceWorkerStore(SQLBaseStore):
     def get_cached_device_list_changes(
         self,
         from_key: int,
-    ) -> Optional[Set[str]]:
+    ) -> Optional[List[str]]:
         """Get set of users whose devices have changed since `from_key`, or None
         if that information is not in our cache.
         """
@@ -737,7 +746,7 @@ class DeviceWorkerStore(SQLBaseStore):
     async def get_users_whose_devices_changed(
         self,
         from_key: int,
-        user_ids: Optional[Iterable[str]] = None,
+        user_ids: Optional[Collection[str]] = None,
         to_key: Optional[int] = None,
     ) -> Set[str]:
         """Get set of users whose devices have changed since `from_key` that
@@ -757,6 +766,7 @@ class DeviceWorkerStore(SQLBaseStore):
         """
         # Get set of users who *may* have changed. Users not in the returned
         # list have definitely not changed.
+        user_ids_to_check: Optional[Collection[str]]
         if user_ids is None:
             # Get set of all users that have had device list changes since 'from_key'
             user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
@@ -772,7 +782,7 @@ class DeviceWorkerStore(SQLBaseStore):
             return set()
 
         def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
-            changes = set()
+            changes: Set[str] = set()
 
             stream_id_where_clause = "stream_id > ?"
             sql_args = [from_key]
@@ -788,6 +798,9 @@ class DeviceWorkerStore(SQLBaseStore):
             """
 
             # Query device changes with a batch of users at a time
+            # Assertion for mypy's benefit; see also
+            # https://mypy.readthedocs.io/en/stable/common_issues.html#narrowing-and-inner-functions
+            assert user_ids_to_check is not None
             for chunk in batch_iter(user_ids_to_check, 100):
                 clause, args = make_in_list_sql_clause(
                     txn.database_engine, "user_id", chunk
@@ -854,7 +867,9 @@ class DeviceWorkerStore(SQLBaseStore):
         if last_id == current_id:
             return [], current_id, False
 
-        def _get_all_device_list_changes_for_remotes(txn):
+        def _get_all_device_list_changes_for_remotes(
+            txn: Cursor,
+        ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
             # This query Does The Right Thing where it'll correctly apply the
             # bounds to the inner queries.
             sql = """
@@ -913,7 +928,7 @@ class DeviceWorkerStore(SQLBaseStore):
             desc="get_device_list_last_stream_id_for_remotes",
         )
 
-        results = {user_id: None for user_id in user_ids}
+        results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids}
         results.update({row["user_id"]: row["stream_id"] for row in rows})
 
         return results
@@ -1193,6 +1208,65 @@ class DeviceWorkerStore(SQLBaseStore):
 
         return devices
 
+    @cached()
+    async def _get_min_device_lists_changes_in_room(self) -> int:
+        """Returns the minimum stream ID that we have entries for
+        `device_lists_changes_in_room`
+        """
+
+        return await self.db_pool.simple_select_one_onecol(
+            table="device_lists_changes_in_room",
+            keyvalues={},
+            retcol="COALESCE(MIN(stream_id), 0)",
+            desc="get_min_device_lists_changes_in_room",
+        )
+
+    async def get_device_list_changes_in_rooms(
+        self, room_ids: Collection[str], from_id: int
+    ) -> Optional[Set[str]]:
+        """Return the set of users whose devices have changed in the given rooms
+        since the given stream ID.
+
+        Returns None if the given stream ID is too old.
+        """
+
+        if not room_ids:
+            return set()
+
+        min_stream_id = await self._get_min_device_lists_changes_in_room()
+
+        if min_stream_id > from_id:
+            return None
+
+        sql = """
+            SELECT DISTINCT user_id FROM device_lists_changes_in_room
+            WHERE {clause} AND stream_id >= ?
+        """
+
+        def _get_device_list_changes_in_rooms_txn(
+            txn: LoggingTransaction,
+            clause: str,
+            args: List[Any],
+        ) -> Set[str]:
+            txn.execute(sql.format(clause=clause), args)
+            return {user_id for user_id, in txn}
+
+        changes = set()
+        for chunk in batch_iter(room_ids, 1000):
+            clause, args = make_in_list_sql_clause(
+                self.database_engine, "room_id", chunk
+            )
+            args.append(from_id)
+
+            changes |= await self.db_pool.runInteraction(
+                "get_device_list_changes_in_rooms",
+                _get_device_list_changes_in_rooms_txn,
+                clause,
+                args,
+            )
+
+        return changes
+
 
 class DeviceBackgroundUpdateStore(SQLBaseStore):
     def __init__(
@@ -1240,15 +1314,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
             self._remove_duplicate_outbound_pokes,
         )
 
-        # a pair of background updates that were added during the 1.14 release cycle,
-        # but replaced with 58/06dlols_unique_idx.py
-        self.db_pool.updates.register_noop_background_update(
-            "device_lists_outbound_last_success_unique_idx",
-        )
-        self.db_pool.updates.register_noop_background_update(
-            "drop_device_lists_outbound_last_success_non_unique_idx",
-        )
-
     async def _drop_device_list_streams_non_unique_indexes(
         self, progress: JsonDict, batch_size: int
     ) -> int:
@@ -1346,9 +1411,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
 
         # Map of (user_id, device_id) -> bool. If there is an entry that implies
         # the device exists.
-        self.device_id_exists_cache = LruCache(
-            cache_name="device_id_exists", max_size=10000
-        )
+        self.device_id_exists_cache: LruCache[
+            Tuple[str, str], Literal[True]
+        ] = LruCache(cache_name="device_id_exists", max_size=10000)
 
     async def store_device(
         self,
@@ -1660,7 +1725,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 context,
             )
 
-        async with self._device_list_id_gen.get_next_mult(
+        async with self._device_list_id_gen.get_next_mult(  # type: ignore[attr-defined]
             len(device_ids)
         ) as stream_ids:
             await self.db_pool.runInteraction(
@@ -1713,7 +1778,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         device_ids: Iterable[str],
         hosts: Collection[str],
         stream_ids: List[int],
-        context: Dict[str, str],
+        context: Optional[Dict[str, str]],
     ) -> None:
         for host in hosts:
             txn.call_after(
@@ -1884,7 +1949,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 [],
             )
 
-        async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:
+        async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:  # type: ignore[attr-defined]
             return await self.db_pool.runInteraction(
                 "add_device_list_outbound_pokes",
                 add_device_list_outbound_pokes_txn,
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 9b293475c8..46c0d06157 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -22,11 +22,14 @@ from typing import (
     List,
     Optional,
     Tuple,
+    Union,
     cast,
+    overload,
 )
 
 import attr
 from canonicaljson import encode_canonical_json
+from typing_extensions import Literal
 
 from synapse.api.constants import DeviceKeyAlgorithms
 from synapse.appservice import (
@@ -113,7 +116,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             user_devices = devices[user_id]
             results = []
             for device_id, device in user_devices.items():
-                result = {"device_id": device_id}
+                result: JsonDict = {"device_id": device_id}
 
                 keys = device.keys
                 if keys:
@@ -143,7 +146,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             key data.  The key data will be a dict in the same format as the
             DeviceKeys type returned by POST /_matrix/client/r0/keys/query.
         """
-        set_tag("query_list", query_list)
+        set_tag("query_list", str(query_list))
         if not query_list:
             return {}
 
@@ -156,6 +159,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             rv[user_id] = {}
             for device_id, device_info in device_keys.items():
                 r = device_info.keys
+                if r is None:
+                    continue
+
                 r["unsigned"] = {}
                 display_name = device_info.display_name
                 if display_name is not None:
@@ -164,13 +170,42 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 
         return rv
 
+    @overload
+    async def get_e2e_device_keys_and_signatures(
+        self,
+        query_list: Collection[Tuple[str, Optional[str]]],
+        include_all_devices: Literal[False] = False,
+    ) -> Dict[str, Dict[str, DeviceKeyLookupResult]]:
+        ...
+
+    @overload
+    async def get_e2e_device_keys_and_signatures(
+        self,
+        query_list: Collection[Tuple[str, Optional[str]]],
+        include_all_devices: bool = False,
+        include_deleted_devices: Literal[False] = False,
+    ) -> Dict[str, Dict[str, DeviceKeyLookupResult]]:
+        ...
+
+    @overload
+    async def get_e2e_device_keys_and_signatures(
+        self,
+        query_list: Collection[Tuple[str, Optional[str]]],
+        include_all_devices: Literal[True],
+        include_deleted_devices: Literal[True],
+    ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
+        ...
+
     @trace
     async def get_e2e_device_keys_and_signatures(
         self,
-        query_list: List[Tuple[str, Optional[str]]],
+        query_list: Collection[Tuple[str, Optional[str]]],
         include_all_devices: bool = False,
         include_deleted_devices: bool = False,
-    ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
+    ) -> Union[
+        Dict[str, Dict[str, DeviceKeyLookupResult]],
+        Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]],
+    ]:
         """Fetch a list of device keys
 
         Any cross-signatures made on the keys by the owner of the device are also
@@ -383,7 +418,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
             set_tag("user_id", user_id)
             set_tag("device_id", device_id)
-            set_tag("new_keys", new_keys)
+            set_tag("new_keys", str(new_keys))
             # We are protected from race between lookup and insertion due to
             # a unique constraint. If there is a race of two calls to
             # `add_e2e_one_time_keys` then they'll conflict and we will only
@@ -1044,7 +1079,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
                 _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
                 db_autocommit = False
 
-            row = await self.db_pool.runInteraction(
+            claim_row = await self.db_pool.runInteraction(
                 "claim_e2e_one_time_keys",
                 _claim_e2e_one_time_key,
                 user_id,
@@ -1052,11 +1087,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
                 algorithm,
                 db_autocommit=db_autocommit,
             )
-            if row:
+            if claim_row:
                 device_results = results.setdefault(user_id, {}).setdefault(
                     device_id, {}
                 )
-                device_results[row[0]] = row[1]
+                device_results[claim_row[0]] = claim_row[1]
                 continue
 
             # No one-time key available, so see if there's a fallback
@@ -1126,7 +1161,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             set_tag("user_id", user_id)
             set_tag("device_id", device_id)
             set_tag("time_now", time_now)
-            set_tag("device_keys", device_keys)
+            set_tag("device_keys", str(device_keys))
 
             old_key_json = self.db_pool.simple_select_one_onecol_txn(
                 txn,
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index eec55b6478..c836078da6 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -33,6 +33,7 @@ from synapse.api.constants import MAX_DEPTH, EventTypes
 from synapse.api.errors import StoreError
 from synapse.api.room_versions import EventFormatVersions, RoomVersion
 from synapse.events import EventBase, make_event_from_dict
+from synapse.logging.opentracing import tag_args, trace
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
@@ -126,6 +127,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         )
         return await self.get_events_as_list(event_ids)
 
+    @trace
+    @tag_args
     async def get_auth_chain_ids(
         self,
         room_id: str,
@@ -709,6 +712,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         # Return all events where not all sets can reach them.
         return {eid for eid, n in event_to_missing_sets.items() if n}
 
+    @trace
+    @tag_args
     async def get_oldest_event_ids_with_depth_in_room(
         self, room_id: str
     ) -> List[Tuple[str, int]]:
@@ -767,6 +772,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             room_id,
         )
 
+    @trace
     async def get_insertion_event_backward_extremities_in_room(
         self, room_id: str
     ) -> List[Tuple[str, int]]:
@@ -1339,6 +1345,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         event_results.reverse()
         return event_results
 
+    @trace
+    @tag_args
     async def get_successor_events(self, event_id: str) -> List[str]:
         """Fetch all events that have the given event as a prev event
 
@@ -1375,6 +1383,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             _delete_old_forward_extrem_cache_txn,
         )
 
+    @trace
     async def insert_insertion_extremity(self, event_id: str, room_id: str) -> None:
         await self.db_pool.simple_upsert(
             table="insertion_event_extremities",
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index b019979350..f4a07de2a3 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -12,18 +12,92 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+"""Responsible for storing and fetching push actions / notifications.
+
+There are two main uses for push actions:
+  1. Sending out push to a user's device; and
+  2. Tracking per-room per-user notification counts (used in sync requests).
+
+For the former we simply use the `event_push_actions` table, which contains all
+the calculated actions for a given user (which were calculated by the
+`BulkPushRuleEvaluator`).
+
+For the latter we could simply count the number of rows in `event_push_actions`
+table for a given room/user, but in practice this is *very* heavyweight when
+there were a large number of notifications (due to e.g. the user never reading a
+room). Plus, keeping all push actions indefinitely uses a lot of disk space.
+
+To fix these issues, we add a new table `event_push_summary` that tracks
+per-user per-room counts of all notifications that happened before a stream
+ordering S. Thus, to get the notification count for a user / room we can simply
+query a single row in `event_push_summary` and count the number of rows in
+`event_push_actions` with a stream ordering larger than S (and as long as S is
+"recent", the number of rows needing to be scanned will be small).
+
+The `event_push_summary` table is updated via a background job that periodically
+chooses a new stream ordering S' (usually the latest stream ordering), counts
+all notifications in `event_push_actions` between the existing S and S', and
+adds them to the existing counts in `event_push_summary`.
+
+This allows us to delete old rows from `event_push_actions` once those rows have
+been counted and added to `event_push_summary` (we call this process
+"rotation").
+
+
+We need to handle when a user sends a read receipt to the room. Again this is
+done as a background process. For each receipt we clear the row in
+`event_push_summary` and count the number of notifications in
+`event_push_actions` that happened after the receipt but before S, and insert
+that count into `event_push_summary` (If the receipt happened *after* S then we
+simply clear the `event_push_summary`.)
+
+Note that its possible that if the read receipt is for an old event the relevant
+`event_push_actions` rows will have been rotated and we get the wrong count
+(it'll be too low). We accept this as a rare edge case that is unlikely to
+impact the user much (since the vast majority of read receipts will be for the
+latest event).
+
+The last complication is to handle the race where we request the notifications
+counts after a user sends a read receipt into the room, but *before* the
+background update handles the receipt (without any special handling the counts
+would be outdated). We fix this by including in `event_push_summary` the read
+receipt we used when updating `event_push_summary`, and every time we query the
+table we check if that matches the most recent read receipt in the room. If yes,
+continue as above, if not we simply query the `event_push_actions` table
+directly.
+
+Since read receipts are almost always for recent events, scanning the
+`event_push_actions` table in this case is unlikely to be a problem. Even if it
+is a problem, it is temporary until the background job handles the new read
+receipt.
+"""
+
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
+from typing import (
+    TYPE_CHECKING,
+    Collection,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+)
 
 import attr
 
+from synapse.api.constants import ReceiptTypes
 from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
     LoggingDatabaseConnection,
     LoggingTransaction,
 )
+from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
+from synapse.storage.databases.main.stream import StreamWorkerStore
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 
@@ -79,18 +153,20 @@ class UserPushAction(EmailPushAction):
     profile_tag: str
 
 
-@attr.s(slots=True, frozen=True, auto_attribs=True)
+@attr.s(slots=True, auto_attribs=True)
 class NotifCounts:
     """
     The per-user, per-room count of notifications. Used by sync and push.
     """
 
-    notify_count: int
-    unread_count: int
-    highlight_count: int
+    notify_count: int = 0
+    unread_count: int = 0
+    highlight_count: int = 0
 
 
-def _serialize_action(actions: List[Union[dict, str]], is_highlight: bool) -> str:
+def _serialize_action(
+    actions: Collection[Union[Mapping, str]], is_highlight: bool
+) -> str:
     """Custom serializer for actions. This allows us to "compress" common actions.
 
     We use the fact that most users have the same actions for notifs (and for
@@ -119,7 +195,7 @@ def _deserialize_action(actions: str, is_highlight: bool) -> List[Union[dict, st
         return DEFAULT_NOTIF_ACTION
 
 
-class EventPushActionsWorkerStore(SQLBaseStore):
+class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBaseStore):
     def __init__(
         self,
         database: DatabasePool,
@@ -140,23 +216,30 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             self._find_stream_orderings_for_times, 10 * 60 * 1000
         )
 
-        self._rotate_delay = 3
         self._rotate_count = 10000
         self._doing_notif_rotation = False
         if hs.config.worker.run_background_tasks:
             self._rotate_notif_loop = self._clock.looping_call(
-                self._rotate_notifs, 30 * 60 * 1000
+                self._rotate_notifs, 30 * 1000
             )
 
-    @cached(num_args=3, tree=True, max_entries=5000)
+        self.db_pool.updates.register_background_index_update(
+            "event_push_summary_unique_index",
+            index_name="event_push_summary_unique_index",
+            table="event_push_summary",
+            columns=["user_id", "room_id"],
+            unique=True,
+            replaces_index="event_push_summary_user_rm",
+        )
+
+    @cached(tree=True, max_entries=5000)
     async def get_unread_event_push_actions_by_room_for_user(
         self,
         room_id: str,
         user_id: str,
-        last_read_event_id: Optional[str],
     ) -> NotifCounts:
         """Get the notification count, the highlight count and the unread message count
-        for a given user in a given room after the given read receipt.
+        for a given user in a given room after their latest read receipt.
 
         Note that this function assumes the user to be a current member of the room,
         since it's either called by the sync handler to handle joined room entries, or by
@@ -165,20 +248,16 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         Args:
             room_id: The room to retrieve the counts in.
             user_id: The user to retrieve the counts for.
-            last_read_event_id: The event associated with the latest read receipt for
-                this user in this room. None if no receipt for this user in this room.
 
         Returns
-            A dict containing the counts mentioned earlier in this docstring,
-            respectively under the keys "notify_count", "highlight_count" and
-            "unread_count".
+            A NotifCounts object containing the notification count, the highlight count
+            and the unread message count.
         """
         return await self.db_pool.runInteraction(
             "get_unread_event_push_actions_by_room",
             self._get_unread_counts_by_receipt_txn,
             room_id,
             user_id,
-            last_read_event_id,
         )
 
     def _get_unread_counts_by_receipt_txn(
@@ -186,20 +265,23 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         txn: LoggingTransaction,
         room_id: str,
         user_id: str,
-        last_read_event_id: Optional[str],
     ) -> NotifCounts:
-        stream_ordering = None
+        # Get the stream ordering of the user's latest receipt in the room.
+        result = self.get_last_receipt_for_user_txn(
+            txn,
+            user_id,
+            room_id,
+            receipt_types=(
+                ReceiptTypes.READ,
+                ReceiptTypes.READ_PRIVATE,
+            ),
+        )
 
-        if last_read_event_id is not None:
-            stream_ordering = self.get_stream_id_for_event_txn(  # type: ignore[attr-defined]
-                txn,
-                last_read_event_id,
-                allow_none=True,
-            )
+        if result:
+            _, stream_ordering = result
 
-        if stream_ordering is None:
-            # Either last_read_event_id is None, or it's an event we don't have (e.g.
-            # because it's been purged), in which case retrieve the stream ordering for
+        else:
+            # If the user has no receipts in the room, retrieve the stream ordering for
             # the latest membership event from this user in this room (which we assume is
             # a join).
             event_id = self.db_pool.simple_select_one_onecol_txn(
@@ -209,57 +291,159 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 retcol="event_id",
             )
 
-            stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)  # type: ignore[attr-defined]
+            stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)
 
         return self._get_unread_counts_by_pos_txn(
             txn, room_id, user_id, stream_ordering
         )
 
     def _get_unread_counts_by_pos_txn(
-        self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        user_id: str,
+        receipt_stream_ordering: int,
     ) -> NotifCounts:
-        sql = (
-            "SELECT"
-            "   COUNT(CASE WHEN notif = 1 THEN 1 END),"
-            "   COUNT(CASE WHEN highlight = 1 THEN 1 END),"
-            "   COUNT(CASE WHEN unread = 1 THEN 1 END)"
-            " FROM event_push_actions ea"
-            " WHERE user_id = ?"
-            "   AND room_id = ?"
-            "   AND stream_ordering > ?"
-        )
+        """Get the number of unread messages for a user/room that have happened
+        since the given stream ordering.
 
-        txn.execute(sql, (user_id, room_id, stream_ordering))
-        row = txn.fetchone()
+        Args:
+            txn: The database transaction.
+            room_id: The room ID to get unread counts for.
+            user_id: The user ID to get unread counts for.
+            receipt_stream_ordering: The stream ordering of the user's latest
+                receipt in the room. If there are no receipts, the stream ordering
+                of the user's join event.
 
-        (notif_count, highlight_count, unread_count) = (0, 0, 0)
+        Returns
+            A NotifCounts object containing the notification count, the highlight count
+            and the unread message count.
+        """
 
-        if row:
-            (notif_count, highlight_count, unread_count) = row
+        counts = NotifCounts()
 
+        # First we pull the counts from the summary table.
+        #
+        # We check that `last_receipt_stream_ordering` matches the stream
+        # ordering given. If it doesn't match then a new read receipt has arrived and
+        # we haven't yet updated the counts in `event_push_summary` to reflect
+        # that; in that case we simply ignore `event_push_summary` counts
+        # and do a manual count of all of the rows in the `event_push_actions` table
+        # for this user/room.
+        #
+        # If `last_receipt_stream_ordering` is null then that means it's up to
+        # date (as the row was written by an older version of Synapse that
+        # updated `event_push_summary` synchronously when persisting a new read
+        # receipt).
         txn.execute(
             """
-                SELECT notif_count, unread_count FROM event_push_summary
-                WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
+                SELECT stream_ordering, notif_count, COALESCE(unread_count, 0)
+                FROM event_push_summary
+                WHERE room_id = ? AND user_id = ?
+                AND (
+                    (last_receipt_stream_ordering IS NULL AND stream_ordering > ?)
+                    OR last_receipt_stream_ordering = ?
+                )
             """,
-            (room_id, user_id, stream_ordering),
+            (room_id, user_id, receipt_stream_ordering, receipt_stream_ordering),
         )
         row = txn.fetchone()
 
+        summary_stream_ordering = 0
+        if row:
+            summary_stream_ordering = row[0]
+            counts.notify_count += row[1]
+            counts.unread_count += row[2]
+
+        # Next we need to count highlights, which aren't summarised
+        sql = """
+            SELECT COUNT(*) FROM event_push_actions
+            WHERE user_id = ?
+                AND room_id = ?
+                AND stream_ordering > ?
+                AND highlight = 1
+        """
+        txn.execute(sql, (user_id, room_id, receipt_stream_ordering))
+        row = txn.fetchone()
         if row:
-            notif_count += row[0]
-
-            if row[1] is not None:
-                # The unread_count column of event_push_summary is NULLable, so we need
-                # to make sure we don't try increasing the unread counts if it's NULL
-                # for this row.
-                unread_count += row[1]
-
-        return NotifCounts(
-            notify_count=notif_count,
-            unread_count=unread_count,
-            highlight_count=highlight_count,
+            counts.highlight_count += row[0]
+
+        # Finally we need to count push actions that aren't included in the
+        # summary returned above. This might be due to recent events that haven't
+        # been summarised yet or the summary is out of date due to a recent read
+        # receipt.
+        start_unread_stream_ordering = max(
+            receipt_stream_ordering, summary_stream_ordering
         )
+        notify_count, unread_count = self._get_notif_unread_count_for_user_room(
+            txn, room_id, user_id, start_unread_stream_ordering
+        )
+
+        counts.notify_count += notify_count
+        counts.unread_count += unread_count
+
+        return counts
+
+    def _get_notif_unread_count_for_user_room(
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        user_id: str,
+        stream_ordering: int,
+        max_stream_ordering: Optional[int] = None,
+    ) -> Tuple[int, int]:
+        """Returns the notify and unread counts from `event_push_actions` for
+        the given user/room in the given range.
+
+        Does not consult `event_push_summary` table, which may include push
+        actions that have been deleted from `event_push_actions` table.
+
+        Args:
+            txn: The database transaction.
+            room_id: The room ID to get unread counts for.
+            user_id: The user ID to get unread counts for.
+            stream_ordering: The (exclusive) minimum stream ordering to consider.
+            max_stream_ordering: The (inclusive) maximum stream ordering to consider.
+                If this is not given, then no maximum is applied.
+
+        Return:
+            A tuple of the notif count and unread count in the given range.
+        """
+
+        # If there have been no events in the room since the stream ordering,
+        # there can't be any push actions either.
+        if not self._events_stream_cache.has_entity_changed(room_id, stream_ordering):
+            return 0, 0
+
+        clause = ""
+        args = [user_id, room_id, stream_ordering]
+        if max_stream_ordering is not None:
+            clause = "AND ea.stream_ordering <= ?"
+            args.append(max_stream_ordering)
+
+            # If the max stream ordering is less than the min stream ordering,
+            # then obviously there are zero push actions in that range.
+            if max_stream_ordering <= stream_ordering:
+                return 0, 0
+
+        sql = f"""
+            SELECT
+               COUNT(CASE WHEN notif = 1 THEN 1 END),
+               COUNT(CASE WHEN unread = 1 THEN 1 END)
+             FROM event_push_actions ea
+             WHERE user_id = ?
+               AND room_id = ?
+               AND ea.stream_ordering > ?
+               {clause}
+        """
+
+        txn.execute(sql, args)
+        row = txn.fetchone()
+
+        if row:
+            return cast(Tuple[int, int], row)
+
+        return 0, 0
 
     async def get_push_action_users_in_range(
         self, min_stream_ordering: int, max_stream_ordering: int
@@ -274,6 +458,31 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 
         return await self.db_pool.runInteraction("get_push_action_users_in_range", f)
 
+    def _get_receipts_by_room_txn(
+        self, txn: LoggingTransaction, user_id: str
+    ) -> List[Tuple[str, int]]:
+        receipt_types_clause, args = make_in_list_sql_clause(
+            self.database_engine,
+            "receipt_type",
+            (
+                ReceiptTypes.READ,
+                ReceiptTypes.READ_PRIVATE,
+            ),
+        )
+
+        sql = f"""
+            SELECT room_id, MAX(stream_ordering)
+            FROM receipts_linearized
+            INNER JOIN events USING (room_id, event_id)
+            WHERE {receipt_types_clause}
+            AND user_id = ?
+            GROUP BY room_id
+        """
+
+        args.extend((user_id,))
+        txn.execute(sql, args)
+        return cast(List[Tuple[str, int]], txn.fetchall())
+
     async def get_unread_push_actions_for_user_in_range_for_http(
         self,
         user_id: str,
@@ -296,81 +505,46 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             The list will be ordered by ascending stream_ordering.
             The list will have between 0~limit entries.
         """
-        # find rooms that have a read receipt in them and return the next
-        # push actions
-        def get_after_receipt(
-            txn: LoggingTransaction,
-        ) -> List[Tuple[str, str, int, str, bool]]:
-            # find rooms that have a read receipt in them and return the next
-            # push actions
-            sql = (
-                "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
-                "   ep.highlight "
-                " FROM ("
-                "   SELECT room_id,"
-                "       MAX(stream_ordering) as stream_ordering"
-                "   FROM events"
-                "   INNER JOIN receipts_linearized USING (room_id, event_id)"
-                "   WHERE receipt_type = 'm.read' AND user_id = ?"
-                "   GROUP BY room_id"
-                ") AS rl,"
-                " event_push_actions AS ep"
-                " WHERE"
-                "   ep.room_id = rl.room_id"
-                "   AND ep.stream_ordering > rl.stream_ordering"
-                "   AND ep.user_id = ?"
-                "   AND ep.stream_ordering > ?"
-                "   AND ep.stream_ordering <= ?"
-                "   AND ep.notif = 1"
-                " ORDER BY ep.stream_ordering ASC LIMIT ?"
-            )
-            args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
-            txn.execute(sql, args)
-            return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
 
-        after_read_receipt = await self.db_pool.runInteraction(
-            "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
+        receipts_by_room = dict(
+            await self.db_pool.runInteraction(
+                "get_unread_push_actions_for_user_in_range_http_receipts",
+                self._get_receipts_by_room_txn,
+                user_id=user_id,
+            ),
         )
 
-        # There are rooms with push actions in them but you don't have a read receipt in
-        # them e.g. rooms you've been invited to, so get push actions for rooms which do
-        # not have read receipts in them too.
-        def get_no_receipt(
+        def get_push_actions_txn(
             txn: LoggingTransaction,
         ) -> List[Tuple[str, str, int, str, bool]]:
-            sql = (
-                "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
-                "   ep.highlight "
-                " FROM event_push_actions AS ep"
-                " INNER JOIN events AS e USING (room_id, event_id)"
-                " WHERE"
-                "   ep.room_id NOT IN ("
-                "     SELECT room_id FROM receipts_linearized"
-                "       WHERE receipt_type = 'm.read' AND user_id = ?"
-                "       GROUP BY room_id"
-                "   )"
-                "   AND ep.user_id = ?"
-                "   AND ep.stream_ordering > ?"
-                "   AND ep.stream_ordering <= ?"
-                "   AND ep.notif = 1"
-                " ORDER BY ep.stream_ordering ASC LIMIT ?"
-            )
-            args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
-            txn.execute(sql, args)
+            sql = """
+                SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, ep.highlight
+                FROM event_push_actions AS ep
+                WHERE
+                    ep.user_id = ?
+                    AND ep.stream_ordering > ?
+                    AND ep.stream_ordering <= ?
+                    AND ep.notif = 1
+                ORDER BY ep.stream_ordering ASC LIMIT ?
+            """
+            txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit))
             return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
 
-        no_read_receipt = await self.db_pool.runInteraction(
-            "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
+        push_actions = await self.db_pool.runInteraction(
+            "get_unread_push_actions_for_user_in_range_http", get_push_actions_txn
         )
 
         notifs = [
             HttpPushAction(
-                event_id=row[0],
-                room_id=row[1],
-                stream_ordering=row[2],
-                actions=_deserialize_action(row[3], row[4]),
+                event_id=event_id,
+                room_id=room_id,
+                stream_ordering=stream_ordering,
+                actions=_deserialize_action(actions, highlight),
             )
-            for row in after_read_receipt + no_read_receipt
+            for event_id, room_id, stream_ordering, actions, highlight in push_actions
+            # Only include push actions with a stream ordering after any receipt, or without any
+            # receipt present (invited to but never read rooms).
+            if stream_ordering > receipts_by_room.get(room_id, 0)
         ]
 
         # Now sort it so it's ordered correctly, since currently it will
@@ -405,82 +579,50 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             The list will be ordered by descending received_ts.
             The list will have between 0~limit entries.
         """
-        # find rooms that have a read receipt in them and return the most recent
-        # push actions
-        def get_after_receipt(
-            txn: LoggingTransaction,
-        ) -> List[Tuple[str, str, int, str, bool, int]]:
-            sql = (
-                "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
-                "  ep.highlight, e.received_ts"
-                " FROM ("
-                "   SELECT room_id,"
-                "       MAX(stream_ordering) as stream_ordering"
-                "   FROM events"
-                "   INNER JOIN receipts_linearized USING (room_id, event_id)"
-                "   WHERE receipt_type = 'm.read' AND user_id = ?"
-                "   GROUP BY room_id"
-                ") AS rl,"
-                " event_push_actions AS ep"
-                " INNER JOIN events AS e USING (room_id, event_id)"
-                " WHERE"
-                "   ep.room_id = rl.room_id"
-                "   AND ep.stream_ordering > rl.stream_ordering"
-                "   AND ep.user_id = ?"
-                "   AND ep.stream_ordering > ?"
-                "   AND ep.stream_ordering <= ?"
-                "   AND ep.notif = 1"
-                " ORDER BY ep.stream_ordering DESC LIMIT ?"
-            )
-            args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
-            txn.execute(sql, args)
-            return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
 
-        after_read_receipt = await self.db_pool.runInteraction(
-            "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
+        receipts_by_room = dict(
+            await self.db_pool.runInteraction(
+                "get_unread_push_actions_for_user_in_range_email_receipts",
+                self._get_receipts_by_room_txn,
+                user_id=user_id,
+            ),
         )
 
-        # There are rooms with push actions in them but you don't have a read receipt in
-        # them e.g. rooms you've been invited to, so get push actions for rooms which do
-        # not have read receipts in them too.
-        def get_no_receipt(
+        def get_push_actions_txn(
             txn: LoggingTransaction,
         ) -> List[Tuple[str, str, int, str, bool, int]]:
-            sql = (
-                "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
-                "   ep.highlight, e.received_ts"
-                " FROM event_push_actions AS ep"
-                " INNER JOIN events AS e USING (room_id, event_id)"
-                " WHERE"
-                "   ep.room_id NOT IN ("
-                "     SELECT room_id FROM receipts_linearized"
-                "       WHERE receipt_type = 'm.read' AND user_id = ?"
-                "       GROUP BY room_id"
-                "   )"
-                "   AND ep.user_id = ?"
-                "   AND ep.stream_ordering > ?"
-                "   AND ep.stream_ordering <= ?"
-                "   AND ep.notif = 1"
-                " ORDER BY ep.stream_ordering DESC LIMIT ?"
-            )
-            args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
-            txn.execute(sql, args)
+            sql = """
+                SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
+                    ep.highlight, e.received_ts
+                FROM event_push_actions AS ep
+                INNER JOIN events AS e USING (room_id, event_id)
+                WHERE
+                    ep.user_id = ?
+                    AND ep.stream_ordering > ?
+                    AND ep.stream_ordering <= ?
+                    AND ep.notif = 1
+                ORDER BY ep.stream_ordering DESC LIMIT ?
+            """
+            txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit))
             return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
 
-        no_read_receipt = await self.db_pool.runInteraction(
-            "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
+        push_actions = await self.db_pool.runInteraction(
+            "get_unread_push_actions_for_user_in_range_email", get_push_actions_txn
         )
 
         # Make a list of dicts from the two sets of results.
         notifs = [
             EmailPushAction(
-                event_id=row[0],
-                room_id=row[1],
-                stream_ordering=row[2],
-                actions=_deserialize_action(row[3], row[4]),
-                received_ts=row[5],
+                event_id=event_id,
+                room_id=room_id,
+                stream_ordering=stream_ordering,
+                actions=_deserialize_action(actions, highlight),
+                received_ts=received_ts,
             )
-            for row in after_read_receipt + no_read_receipt
+            for event_id, room_id, stream_ordering, actions, highlight, received_ts in push_actions
+            # Only include push actions with a stream ordering after any receipt, or without any
+            # receipt present (invited to but never read rooms).
+            if stream_ordering > receipts_by_room.get(room_id, 0)
         ]
 
         # Now sort it so it's ordered correctly, since currently it will
@@ -526,7 +668,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
     async def add_push_actions_to_staging(
         self,
         event_id: str,
-        user_id_actions: Dict[str, List[Union[dict, str]]],
+        user_id_actions: Dict[str, Collection[Union[Mapping, str]]],
         count_as_unread: bool,
     ) -> None:
         """Add the push actions for the event to the push action staging area.
@@ -543,7 +685,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         # This is a helper function for generating the necessary tuple that
         # can be used to insert into the `event_push_actions_staging` table.
         def _gen_entry(
-            user_id: str, actions: List[Union[dict, str]]
+            user_id: str, actions: Collection[Union[Mapping, str]]
         ) -> Tuple[str, str, str, int, int, int]:
             is_highlight = 1 if _action_has_highlight(actions) else 0
             notif = 1 if "notify" in actions else 0
@@ -556,26 +698,14 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 int(count_as_unread),  # unread column
             )
 
-        def _add_push_actions_to_staging_txn(txn: LoggingTransaction) -> None:
-            # We don't use simple_insert_many here to avoid the overhead
-            # of generating lists of dicts.
-
-            sql = """
-                INSERT INTO event_push_actions_staging
-                    (event_id, user_id, actions, notif, highlight, unread)
-                VALUES (?, ?, ?, ?, ?, ?)
-            """
-
-            txn.execute_batch(
-                sql,
-                (
-                    _gen_entry(user_id, actions)
-                    for user_id, actions in user_id_actions.items()
-                ),
-            )
-
-        return await self.db_pool.runInteraction(
-            "add_push_actions_to_staging", _add_push_actions_to_staging_txn
+        await self.db_pool.simple_insert_many(
+            "event_push_actions_staging",
+            keys=("event_id", "user_id", "actions", "notif", "highlight", "unread"),
+            values=[
+                _gen_entry(user_id, actions)
+                for user_id, actions in user_id_actions.items()
+            ],
+            desc="add_push_actions_to_staging",
         )
 
     async def remove_push_actions_from_staging(self, event_id: str) -> None:
@@ -689,12 +819,12 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         # [10, <none>, 20], we should treat this as being equivalent to
         # [10, 10, 20].
         #
-        sql = (
-            "SELECT received_ts FROM events"
-            " WHERE stream_ordering <= ?"
-            " ORDER BY stream_ordering DESC"
-            " LIMIT 1"
-        )
+        sql = """
+            SELECT received_ts FROM events
+            WHERE stream_ordering <= ?
+            ORDER BY stream_ordering DESC
+            LIMIT 1
+        """
 
         while range_end - range_start > 0:
             middle = (range_end + range_start) // 2
@@ -722,14 +852,14 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         self, stream_ordering: int
     ) -> Optional[int]:
         def f(txn: LoggingTransaction) -> Optional[Tuple[int]]:
-            sql = (
-                "SELECT e.received_ts"
-                " FROM event_push_actions AS ep"
-                " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
-                " WHERE ep.stream_ordering > ? AND notif = 1"
-                " ORDER BY ep.stream_ordering ASC"
-                " LIMIT 1"
-            )
+            sql = """
+                SELECT e.received_ts
+                FROM event_push_actions AS ep
+                JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id
+                WHERE ep.stream_ordering > ? AND notif = 1
+                ORDER BY ep.stream_ordering ASC
+                LIMIT 1
+            """
             txn.execute(sql, (stream_ordering,))
             return cast(Optional[Tuple[int]], txn.fetchone())
 
@@ -745,6 +875,19 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         self._doing_notif_rotation = True
 
         try:
+            # First we recalculate push summaries and delete stale push actions
+            # for rooms/users with new receipts.
+            while True:
+                logger.debug("Handling new receipts")
+
+                caught_up = await self.db_pool.runInteraction(
+                    "_handle_new_receipts_for_notifs_txn",
+                    self._handle_new_receipts_for_notifs_txn,
+                )
+                if caught_up:
+                    break
+
+            # Then we update the event push summaries for any new events
             while True:
                 logger.info("Rotating notifications")
 
@@ -753,15 +896,129 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 )
                 if caught_up:
                     break
-                await self.hs.get_clock().sleep(self._rotate_delay)
+
+            # Finally we clear out old event push actions.
+            await self._remove_old_push_actions_that_have_rotated()
         finally:
             self._doing_notif_rotation = False
 
+    def _handle_new_receipts_for_notifs_txn(self, txn: LoggingTransaction) -> bool:
+        """Check for new read receipts and delete from event push actions.
+
+        Any push actions which predate the user's most recent read receipt are
+        now redundant, so we can remove them from `event_push_actions` and
+        update `event_push_summary`.
+
+        Returns true if all new receipts have been processed.
+        """
+
+        limit = 100
+
+        # The (inclusive) receipt stream ID that was previously processed..
+        min_receipts_stream_id = self.db_pool.simple_select_one_onecol_txn(
+            txn,
+            table="event_push_summary_last_receipt_stream_id",
+            keyvalues={},
+            retcol="stream_id",
+        )
+
+        max_receipts_stream_id = self._receipts_id_gen.get_current_token()
+
+        # The (inclusive) event stream ordering that was previously summarised.
+        old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
+            txn,
+            table="event_push_summary_stream_ordering",
+            keyvalues={},
+            retcol="stream_ordering",
+        )
+
+        sql = """
+            SELECT r.stream_id, r.room_id, r.user_id, e.stream_ordering
+            FROM receipts_linearized AS r
+            INNER JOIN events AS e USING (event_id)
+            WHERE ? < r.stream_id AND r.stream_id <= ? AND user_id LIKE ?
+            ORDER BY r.stream_id ASC
+            LIMIT ?
+        """
+
+        # We only want local users, so we add a dodgy filter to the above query
+        # and recheck it below.
+        user_filter = "%:" + self.hs.hostname
+
+        txn.execute(
+            sql,
+            (
+                min_receipts_stream_id,
+                max_receipts_stream_id,
+                user_filter,
+                limit,
+            ),
+        )
+        rows = txn.fetchall()
+
+        # For each new read receipt we delete push actions from before it and
+        # recalculate the summary.
+        for _, room_id, user_id, stream_ordering in rows:
+            # Only handle our own read receipts.
+            if not self.hs.is_mine_id(user_id):
+                continue
+
+            txn.execute(
+                """
+                DELETE FROM event_push_actions
+                WHERE room_id = ?
+                    AND user_id = ?
+                    AND stream_ordering <= ?
+                    AND highlight = 0
+                """,
+                (room_id, user_id, stream_ordering),
+            )
+
+            # Fetch the notification counts between the stream ordering of the
+            # latest receipt and what was previously summarised.
+            notif_count, unread_count = self._get_notif_unread_count_for_user_room(
+                txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering
+            )
+
+            # Replace the previous summary with the new counts.
+            self.db_pool.simple_upsert_txn(
+                txn,
+                table="event_push_summary",
+                keyvalues={"room_id": room_id, "user_id": user_id},
+                values={
+                    "notif_count": notif_count,
+                    "unread_count": unread_count,
+                    "stream_ordering": old_rotate_stream_ordering,
+                    "last_receipt_stream_ordering": stream_ordering,
+                },
+            )
+
+        # We always update `event_push_summary_last_receipt_stream_id` to
+        # ensure that we don't rescan the same receipts for remote users.
+
+        upper_limit = max_receipts_stream_id
+        if len(rows) >= limit:
+            # If we pulled out a limited number of rows we only update the
+            # position to the last receipt we processed, so we continue
+            # processing the rest next iteration.
+            upper_limit = rows[-1][0]
+
+        self.db_pool.simple_update_txn(
+            txn,
+            table="event_push_summary_last_receipt_stream_id",
+            keyvalues={},
+            updatevalues={"stream_id": upper_limit},
+        )
+
+        return len(rows) < limit
+
     def _rotate_notifs_txn(self, txn: LoggingTransaction) -> bool:
-        """Archives older notifications into event_push_summary. Returns whether
-        the archiving process has caught up or not.
+        """Archives older notifications (from event_push_actions) into event_push_summary.
+
+        Returns whether the archiving process has caught up or not.
         """
 
+        # The (inclusive) event stream ordering that was previously summarised.
         old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
             txn,
             table="event_push_summary_stream_ordering",
@@ -776,50 +1033,64 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             SELECT stream_ordering FROM event_push_actions
             WHERE stream_ordering > ?
             ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?
-        """,
+            """,
             (old_rotate_stream_ordering, self._rotate_count),
         )
         stream_row = txn.fetchone()
         if stream_row:
             (offset_stream_ordering,) = stream_row
-            assert self.stream_ordering_day_ago is not None
+
+            # We need to bound by the current token to ensure that we handle
+            # out-of-order writes correctly.
             rotate_to_stream_ordering = min(
-                self.stream_ordering_day_ago, offset_stream_ordering
+                offset_stream_ordering, self._stream_id_gen.get_current_token()
             )
-            caught_up = offset_stream_ordering >= self.stream_ordering_day_ago
+            caught_up = False
         else:
-            rotate_to_stream_ordering = self.stream_ordering_day_ago
+            rotate_to_stream_ordering = self._stream_id_gen.get_current_token()
             caught_up = True
 
         logger.info("Rotating notifications up to: %s", rotate_to_stream_ordering)
 
-        self._rotate_notifs_before_txn(txn, rotate_to_stream_ordering)
+        self._rotate_notifs_before_txn(
+            txn, old_rotate_stream_ordering, rotate_to_stream_ordering
+        )
 
-        # We have caught up iff we were limited by `stream_ordering_day_ago`
         return caught_up
 
     def _rotate_notifs_before_txn(
-        self, txn: LoggingTransaction, rotate_to_stream_ordering: int
+        self,
+        txn: LoggingTransaction,
+        old_rotate_stream_ordering: int,
+        rotate_to_stream_ordering: int,
     ) -> None:
-        old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
-            txn,
-            table="event_push_summary_stream_ordering",
-            keyvalues={},
-            retcol="stream_ordering",
-        )
+        """Archives older notifications (from event_push_actions) into event_push_summary.
+
+        Any event_push_actions between old_rotate_stream_ordering (exclusive) and
+        rotate_to_stream_ordering (inclusive) will be added to the event_push_summary
+        table.
+
+        Args:
+            txn: The database transaction.
+            old_rotate_stream_ordering: The previous maximum event stream ordering.
+            rotate_to_stream_ordering: The new maximum event stream ordering to summarise.
+        """
 
         # Calculate the new counts that should be upserted into event_push_summary
         sql = """
             SELECT user_id, room_id,
                 coalesce(old.%s, 0) + upd.cnt,
-                upd.stream_ordering,
-                old.user_id
+                upd.stream_ordering
             FROM (
                 SELECT user_id, room_id, count(*) as cnt,
-                    max(stream_ordering) as stream_ordering
-                FROM event_push_actions
-                WHERE ? <= stream_ordering AND stream_ordering < ?
-                    AND highlight = 0
+                    max(ea.stream_ordering) as stream_ordering
+                FROM event_push_actions AS ea
+                LEFT JOIN event_push_summary AS old USING (user_id, room_id)
+                WHERE ? < ea.stream_ordering AND ea.stream_ordering <= ?
+                    AND (
+                        old.last_receipt_stream_ordering IS NULL
+                        OR old.last_receipt_stream_ordering < ea.stream_ordering
+                    )
                     AND %s = 1
                 GROUP BY user_id, room_id
             ) AS upd
@@ -842,7 +1113,6 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             summaries[(row[0], row[1])] = _EventPushSummary(
                 unread_count=row[2],
                 stream_ordering=row[3],
-                old_user_id=row[4],
                 notif_count=0,
             )
 
@@ -863,115 +1133,93 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 summaries[(row[0], row[1])] = _EventPushSummary(
                     unread_count=0,
                     stream_ordering=row[3],
-                    old_user_id=row[4],
                     notif_count=row[2],
                 )
 
         logger.info("Rotating notifications, handling %d rows", len(summaries))
 
-        # If the `old.user_id` above is NULL then we know there isn't already an
-        # entry in the table, so we simply insert it. Otherwise we update the
-        # existing table.
-        self.db_pool.simple_insert_many_txn(
+        self.db_pool.simple_upsert_many_txn(
             txn,
             table="event_push_summary",
-            keys=(
-                "user_id",
-                "room_id",
-                "notif_count",
-                "unread_count",
-                "stream_ordering",
-            ),
-            values=[
+            key_names=("user_id", "room_id"),
+            key_values=[(user_id, room_id) for user_id, room_id in summaries],
+            value_names=("notif_count", "unread_count", "stream_ordering"),
+            value_values=[
                 (
-                    user_id,
-                    room_id,
                     summary.notif_count,
                     summary.unread_count,
                     summary.stream_ordering,
                 )
-                for ((user_id, room_id), summary) in summaries.items()
-                if summary.old_user_id is None
+                for summary in summaries.values()
             ],
         )
 
-        txn.execute_batch(
-            """
-                UPDATE event_push_summary
-                SET notif_count = ?, unread_count = ?, stream_ordering = ?
-                WHERE user_id = ? AND room_id = ?
-            """,
-            (
-                (
-                    summary.notif_count,
-                    summary.unread_count,
-                    summary.stream_ordering,
-                    user_id,
-                    room_id,
-                )
-                for ((user_id, room_id), summary) in summaries.items()
-                if summary.old_user_id is not None
-            ),
-        )
-
-        txn.execute(
-            "DELETE FROM event_push_actions"
-            " WHERE ? <= stream_ordering AND stream_ordering < ? AND highlight = 0",
-            (old_rotate_stream_ordering, rotate_to_stream_ordering),
-        )
-
-        logger.info("Rotating notifications, deleted %s push actions", txn.rowcount)
-
         txn.execute(
             "UPDATE event_push_summary_stream_ordering SET stream_ordering = ?",
             (rotate_to_stream_ordering,),
         )
 
-    def _remove_old_push_actions_before_txn(
-        self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
-    ) -> None:
-        """
-        Purges old push actions for a user and room before a given
-        stream_ordering.
-
-        We however keep a months worth of highlighted notifications, so that
-        users can still get a list of recent highlights.
+    async def _remove_old_push_actions_that_have_rotated(self) -> None:
+        """Clear out old push actions that have been summarised."""
 
-        Args:
-            txn: The transaction
-            room_id: Room ID to delete from
-            user_id: user ID to delete for
-            stream_ordering: The lowest stream ordering which will
-                                  not be deleted.
-        """
-        txn.call_after(
-            self.get_unread_event_push_actions_by_room_for_user.invalidate,
-            (room_id, user_id),
+        # We want to clear out anything that is older than a day that *has* already
+        # been rotated.
+        rotated_upto_stream_ordering = await self.db_pool.simple_select_one_onecol(
+            table="event_push_summary_stream_ordering",
+            keyvalues={},
+            retcol="stream_ordering",
         )
 
-        # We need to join on the events table to get the received_ts for
-        # event_push_actions and sqlite won't let us use a join in a delete so
-        # we can't just delete where received_ts < x. Furthermore we can
-        # only identify event_push_actions by a tuple of room_id, event_id
-        # we we can't use a subquery.
-        # Instead, we look up the stream ordering for the last event in that
-        # room received before the threshold time and delete event_push_actions
-        # in the room with a stream_odering before that.
-        txn.execute(
-            "DELETE FROM event_push_actions "
-            " WHERE user_id = ? AND room_id = ? AND "
-            " stream_ordering <= ?"
-            " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
-            (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
+        max_stream_ordering_to_delete = min(
+            rotated_upto_stream_ordering, self.stream_ordering_day_ago
         )
 
-        txn.execute(
-            """
-            DELETE FROM event_push_summary
-            WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
-        """,
-            (room_id, user_id, stream_ordering),
-        )
+        def remove_old_push_actions_that_have_rotated_txn(
+            txn: LoggingTransaction,
+        ) -> bool:
+            # We don't want to clear out too much at a time, so we bound our
+            # deletes.
+            batch_size = self._rotate_count
+
+            txn.execute(
+                """
+                SELECT stream_ordering FROM event_push_actions
+                WHERE stream_ordering <= ? AND highlight = 0
+                ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?
+                """,
+                (
+                    max_stream_ordering_to_delete,
+                    batch_size,
+                ),
+            )
+            stream_row = txn.fetchone()
+
+            if stream_row:
+                (stream_ordering,) = stream_row
+            else:
+                stream_ordering = max_stream_ordering_to_delete
+
+            # We need to use a inclusive bound here to handle the case where a
+            # single stream ordering has more than `batch_size` rows.
+            txn.execute(
+                """
+                DELETE FROM event_push_actions
+                WHERE stream_ordering <= ? AND highlight = 0
+                """,
+                (stream_ordering,),
+            )
+
+            logger.info("Rotating notifications, deleted %s push actions", txn.rowcount)
+
+            return txn.rowcount < batch_size
+
+        while True:
+            done = await self.db_pool.runInteraction(
+                "_remove_old_push_actions_that_have_rotated",
+                remove_old_push_actions_that_have_rotated_txn,
+            )
+            if done:
+                break
 
 
 class EventPushActionsStore(EventPushActionsWorkerStore):
@@ -1000,6 +1248,16 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             where_clause="highlight=1",
         )
 
+        # Add index to make deleting old push actions faster.
+        self.db_pool.updates.register_background_index_update(
+            "event_push_actions_stream_highlight_index",
+            index_name="event_push_actions_stream_highlight_index",
+            table="event_push_actions",
+            columns=["highlight", "stream_ordering"],
+            where_clause="highlight=0",
+            psql_only=True,
+        )
+
     async def get_push_actions_for_user(
         self,
         user_id: str,
@@ -1024,16 +1282,18 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
 
             # NB. This assumes event_ids are globally unique since
             # it makes the query easier to index
-            sql = (
-                "SELECT epa.event_id, epa.room_id,"
-                " epa.stream_ordering, epa.topological_ordering,"
-                " epa.actions, epa.highlight, epa.profile_tag, e.received_ts"
-                " FROM event_push_actions epa, events e"
-                " WHERE epa.event_id = e.event_id"
-                " AND epa.user_id = ? %s"
-                " AND epa.notif = 1"
-                " ORDER BY epa.stream_ordering DESC"
-                " LIMIT ?" % (before_clause,)
+            sql = """
+                SELECT epa.event_id, epa.room_id,
+                    epa.stream_ordering, epa.topological_ordering,
+                    epa.actions, epa.highlight, epa.profile_tag, e.received_ts
+                FROM event_push_actions epa, events e
+                WHERE epa.event_id = e.event_id
+                    AND epa.user_id = ? %s
+                    AND epa.notif = 1
+                ORDER BY epa.stream_ordering DESC
+                LIMIT ?
+            """ % (
+                before_clause,
             )
             txn.execute(sql, args)
             return cast(
@@ -1056,7 +1316,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
         ]
 
 
-def _action_has_highlight(actions: List[Union[dict, str]]) -> bool:
+def _action_has_highlight(actions: Collection[Union[Mapping, str]]) -> bool:
     for action in actions:
         if not isinstance(action, dict):
             continue
@@ -1075,5 +1335,4 @@ class _EventPushSummary:
 
     unread_count: int
     stream_ordering: int
-    old_user_id: str
     notif_count: int
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 17e35cf63e..a4010ee28d 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -16,6 +16,7 @@
 import itertools
 import logging
 from collections import OrderedDict
+from http import HTTPStatus
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -35,9 +36,11 @@ from prometheus_client import Counter
 
 import synapse.metrics
 from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
+from synapse.api.errors import Codes, SynapseError
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase, relation_from_event
 from synapse.events.snapshot import EventContext
+from synapse.logging.opentracing import trace
 from synapse.storage._base import db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
@@ -46,7 +49,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.events_worker import EventCacheEntry
 from synapse.storage.databases.main.search import SearchEntry
-from synapse.storage.engines.postgres import PostgresEngine
+from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import AbstractStreamIdGenerator
 from synapse.storage.util.sequence import SequenceGenerator
 from synapse.types import JsonDict, StateMap, get_domain_from_id
@@ -69,6 +72,24 @@ event_counter = Counter(
 )
 
 
+class PartialStateConflictError(SynapseError):
+    """An internal error raised when attempting to persist an event with partial state
+    after the room containing the event has been un-partial stated.
+
+    This error should be handled by recomputing the event context and trying again.
+
+    This error has an HTTP status code so that it can be transported over replication.
+    It should not be exposed to clients.
+    """
+
+    def __init__(self) -> None:
+        super().__init__(
+            HTTPStatus.CONFLICT,
+            msg="Cannot persist partial state event in un-partial stated room",
+            errcode=Codes.UNKNOWN,
+        )
+
+
 @attr.s(slots=True, auto_attribs=True)
 class DeltaState:
     """Deltas to use to update the `current_state_events` table.
@@ -125,6 +146,7 @@ class PersistEventsStore:
         self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
         self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
 
+    @trace
     async def _persist_events_and_state_updates(
         self,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
@@ -154,6 +176,10 @@ class PersistEventsStore:
 
         Returns:
             Resolves when the events have been persisted
+
+        Raises:
+            PartialStateConflictError: if attempting to persist a partial state event in
+                a room that has been un-partial stated.
         """
 
         # We want to calculate the stream orderings as late as possible, as
@@ -354,6 +380,9 @@ class PersistEventsStore:
                 For each room, a list of the event ids which are the forward
                 extremities.
 
+        Raises:
+            PartialStateConflictError: if attempting to persist a partial state event in
+                a room that has been un-partial stated.
         """
         state_delta_for_room = state_delta_for_room or {}
         new_forward_extremities = new_forward_extremities or {}
@@ -980,16 +1009,16 @@ class PersistEventsStore:
         self,
         room_id: str,
         state_delta: DeltaState,
-        stream_id: int,
     ) -> None:
         """Update the current state stored in the datatabase for the given room"""
 
-        await self.db_pool.runInteraction(
-            "update_current_state",
-            self._update_current_state_txn,
-            state_delta_by_room={room_id: state_delta},
-            stream_id=stream_id,
-        )
+        async with self._stream_id_gen.get_next() as stream_ordering:
+            await self.db_pool.runInteraction(
+                "update_current_state",
+                self._update_current_state_txn,
+                state_delta_by_room={room_id: state_delta},
+                stream_id=stream_ordering,
+            )
 
     def _update_current_state_txn(
         self,
@@ -1266,7 +1295,7 @@ class PersistEventsStore:
         depth_updates: Dict[str, int] = {}
         for event, context in events_and_contexts:
             # Remove the any existing cache entries for the event_ids
-            txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
+            self.store.invalidate_get_event_cache_after_txn(txn, event.event_id)
             # Then update the `stream_ordering` position to mark the latest
             # event as the front of the room. This should not be done for
             # backfilled events because backfilled events have negative
@@ -1304,6 +1333,10 @@ class PersistEventsStore:
 
         Returns:
             new list, without events which are already in the events table.
+
+        Raises:
+            PartialStateConflictError: if attempting to persist a partial state event in
+                a room that has been un-partial stated.
         """
         txn.execute(
             "SELECT event_id, outlier FROM events WHERE event_id in (%s)"
@@ -1315,9 +1348,24 @@ class PersistEventsStore:
             event_id: outlier for event_id, outlier in txn
         }
 
+        logger.debug(
+            "_update_outliers_txn: events=%s have_persisted=%s",
+            [ev.event_id for ev, _ in events_and_contexts],
+            have_persisted,
+        )
+
         to_remove = set()
         for event, context in events_and_contexts:
-            if event.event_id not in have_persisted:
+            outlier_persisted = have_persisted.get(event.event_id)
+            logger.debug(
+                "_update_outliers_txn: event=%s outlier=%s outlier_persisted=%s",
+                event.event_id,
+                event.internal_metadata.is_outlier(),
+                outlier_persisted,
+            )
+
+            # Ignore events which we haven't persisted at all
+            if outlier_persisted is None:
                 continue
 
             to_remove.add(event)
@@ -1327,7 +1375,6 @@ class PersistEventsStore:
                 # was an outlier or not - what we have is at least as good.
                 continue
 
-            outlier_persisted = have_persisted[event.event_id]
             if not event.internal_metadata.is_outlier() and outlier_persisted:
                 # We received a copy of an event that we had already stored as
                 # an outlier in the database. We now have some state at that event
@@ -1338,7 +1385,10 @@ class PersistEventsStore:
                 # events down /sync. In general they will be historical events, so that
                 # doesn't matter too much, but that is not always the case.
 
-                logger.info("Updating state for ex-outlier event %s", event.event_id)
+                logger.info(
+                    "_update_outliers_txn: Updating state for ex-outlier event %s",
+                    event.event_id,
+                )
 
                 # insert into event_to_state_groups.
                 try:
@@ -1442,7 +1492,7 @@ class PersistEventsStore:
                     event.sender,
                     "url" in event.content and isinstance(event.content["url"], str),
                     event.get_state_key(),
-                    context.rejected or None,
+                    context.rejected,
                 )
                 for event, context in events_and_contexts
             ),
@@ -1638,13 +1688,13 @@ class PersistEventsStore:
             if not row["rejects"] and not row["redacts"]:
                 to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
 
-        def prefill() -> None:
+        async def prefill() -> None:
             for cache_entry in to_prefill:
-                self.store._get_event_cache.set(
+                await self.store._get_event_cache.set(
                     (cache_entry.event.event_id,), cache_entry
                 )
 
-        txn.call_after(prefill)
+        txn.async_call_after(prefill)
 
     def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None:
         """Invalidate the caches for the redacted event.
@@ -1653,7 +1703,7 @@ class PersistEventsStore:
         _invalidate_caches_for_event.
         """
         assert event.redacts is not None
-        txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
+        self.store.invalidate_get_event_cache_after_txn(txn, event.redacts)
         txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,))
         txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,))
 
@@ -1766,6 +1816,18 @@ class PersistEventsStore:
                 self.store.get_invited_rooms_for_local_user.invalidate,
                 (event.state_key,),
             )
+            txn.call_after(
+                self.store.get_local_users_in_room.invalidate,
+                (event.room_id,),
+            )
+            txn.call_after(
+                self.store.get_number_joined_users_in_room.invalidate,
+                (event.room_id,),
+            )
+            txn.call_after(
+                self.store.get_user_in_room_with_profile.invalidate,
+                (event.room_id, event.state_key),
+            )
 
             # The `_get_membership_from_event_id` is immutable, except for the
             # case where we look up an event *before* persisting it.
@@ -2215,6 +2277,11 @@ class PersistEventsStore:
         txn: LoggingTransaction,
         events_and_contexts: Collection[Tuple[EventBase, EventContext]],
     ) -> None:
+        """
+        Raises:
+            PartialStateConflictError: if attempting to persist a partial state event in
+                a room that has been un-partial stated.
+        """
         state_groups = {}
         for event, context in events_and_contexts:
             if event.internal_metadata.is_outlier():
@@ -2239,19 +2306,37 @@ class PersistEventsStore:
         # if we have partial state for these events, record the fact. (This happens
         # here rather than in _store_event_txn because it also needs to happen when
         # we de-outlier an event.)
-        self.db_pool.simple_insert_many_txn(
-            txn,
-            table="partial_state_events",
-            keys=("room_id", "event_id"),
-            values=[
-                (
-                    event.room_id,
-                    event.event_id,
-                )
-                for event, ctx in events_and_contexts
-                if ctx.partial_state
-            ],
-        )
+        try:
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="partial_state_events",
+                keys=("room_id", "event_id"),
+                values=[
+                    (
+                        event.room_id,
+                        event.event_id,
+                    )
+                    for event, ctx in events_and_contexts
+                    if ctx.partial_state
+                ],
+            )
+        except self.db_pool.engine.module.IntegrityError:
+            logger.info(
+                "Cannot persist events %s in rooms %s: room has been un-partial stated",
+                [
+                    event.event_id
+                    for event, ctx in events_and_contexts
+                    if ctx.partial_state
+                ],
+                list(
+                    {
+                        event.room_id
+                        for event, ctx in events_and_contexts
+                        if ctx.partial_state
+                    }
+                ),
+            )
+            raise PartialStateConflictError()
 
         self.db_pool.simple_upsert_many_txn(
             txn,
@@ -2296,11 +2381,9 @@ class PersistEventsStore:
         self.db_pool.simple_insert_many_txn(
             txn,
             table="event_edges",
-            keys=("event_id", "prev_event_id", "room_id", "is_state"),
+            keys=("event_id", "prev_event_id"),
             values=[
-                (ev.event_id, e_id, ev.room_id, False)
-                for ev in events
-                for e_id in ev.prev_event_ids()
+                (ev.event_id, e_id) for ev in events for e_id in ev.prev_event_ids()
             ],
         )
 
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index d5f0059665..6e8aeed7b4 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -1,4 +1,4 @@
-# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2022 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -64,6 +64,11 @@ class _BackgroundUpdates:
     INDEX_STREAM_ORDERING2_TS = "index_stream_ordering2_ts"
     REPLACE_STREAM_ORDERING_COLUMN = "replace_stream_ordering_column"
 
+    EVENT_EDGES_DROP_INVALID_ROWS = "event_edges_drop_invalid_rows"
+    EVENT_EDGES_REPLACE_INDEX = "event_edges_replace_index"
+
+    EVENTS_POPULATE_STATE_KEY_REJECTIONS = "events_populate_state_key_rejections"
+
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class _CalculateChainCover:
@@ -177,11 +182,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             self._purged_chain_cover_index,
         )
 
-        # The event_thread_relation background update was replaced with the
-        # event_arbitrary_relations one, which handles any relation to avoid
-        # needed to potentially crawl the entire events table in the future.
-        self.db_pool.updates.register_noop_background_update("event_thread_relation")
-
         self.db_pool.updates.register_background_update_handler(
             "event_arbitrary_relations",
             self._event_arbitrary_relations,
@@ -240,6 +240,26 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
         ################################################################################
 
+        self.db_pool.updates.register_background_update_handler(
+            _BackgroundUpdates.EVENT_EDGES_DROP_INVALID_ROWS,
+            self._background_drop_invalid_event_edges_rows,
+        )
+
+        self.db_pool.updates.register_background_index_update(
+            _BackgroundUpdates.EVENT_EDGES_REPLACE_INDEX,
+            index_name="event_edges_event_id_prev_event_id_idx",
+            table="event_edges",
+            columns=["event_id", "prev_event_id"],
+            unique=True,
+            # the old index which just covered event_id is now redundant.
+            replaces_index="ev_edges_id",
+        )
+
+        self.db_pool.updates.register_background_update_handler(
+            _BackgroundUpdates.EVENTS_POPULATE_STATE_KEY_REJECTIONS,
+            self._background_events_populate_state_key_rejections,
+        )
+
     async def _background_reindex_fields_sender(
         self, progress: JsonDict, batch_size: int
     ) -> int:
@@ -1290,3 +1310,179 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         )
 
         return 0
+
+    async def _background_drop_invalid_event_edges_rows(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        """Drop invalid rows from event_edges
+
+        This only runs for postgres. For SQLite, it all happens synchronously.
+
+        Firstly, drop any rows with is_state=True. These may have been added a long time
+        ago, but they are no longer used.
+
+        We also drop rows that do not correspond to entries in `events`, and add a
+        foreign key.
+        """
+
+        last_event_id = progress.get("last_event_id", "")
+
+        def drop_invalid_event_edges_txn(txn: LoggingTransaction) -> bool:
+            """Returns True if we're done."""
+
+            # first we need to find an endpoint.
+            txn.execute(
+                """
+                SELECT event_id FROM event_edges
+                WHERE event_id > ?
+                ORDER BY event_id
+                LIMIT 1 OFFSET ?
+                """,
+                (last_event_id, batch_size),
+            )
+
+            endpoint = None
+            row = txn.fetchone()
+
+            if row:
+                endpoint = row[0]
+
+            where_clause = "ee.event_id > ?"
+            args = [last_event_id]
+            if endpoint:
+                where_clause += " AND ee.event_id <= ?"
+                args.append(endpoint)
+
+            # now delete any that:
+            #   - have is_state=TRUE, or
+            #   - do not correspond to a row in `events`
+            txn.execute(
+                f"""
+                DELETE FROM event_edges
+                WHERE event_id IN (
+                   SELECT ee.event_id
+                   FROM event_edges ee
+                     LEFT JOIN events ev USING (event_id)
+                   WHERE ({where_clause}) AND
+                     (is_state OR ev.event_id IS NULL)
+                )""",
+                args,
+            )
+
+            logger.info(
+                "cleaned up event_edges up to %s: removed %i/%i rows",
+                endpoint,
+                txn.rowcount,
+                batch_size,
+            )
+
+            if endpoint is not None:
+                self.db_pool.updates._background_update_progress_txn(
+                    txn,
+                    _BackgroundUpdates.EVENT_EDGES_DROP_INVALID_ROWS,
+                    {"last_event_id": endpoint},
+                )
+                return False
+
+            # if that was the final batch, we validate the foreign key.
+            #
+            # The constraint should have been in place and enforced for new rows since
+            # before we started deleting invalid rows, so there's no chance for any
+            # invalid rows to have snuck in the meantime. In other words, this really
+            # ought to succeed.
+            logger.info("cleaned up event_edges; enabling foreign key")
+            txn.execute(
+                "ALTER TABLE event_edges VALIDATE CONSTRAINT event_edges_event_id_fkey"
+            )
+            return True
+
+        done = await self.db_pool.runInteraction(
+            desc="drop_invalid_event_edges", func=drop_invalid_event_edges_txn
+        )
+
+        if done:
+            await self.db_pool.updates._end_background_update(
+                _BackgroundUpdates.EVENT_EDGES_DROP_INVALID_ROWS
+            )
+
+        return batch_size
+
+    async def _background_events_populate_state_key_rejections(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        """Back-populate `events.state_key` and `events.rejection_reason"""
+
+        min_stream_ordering_exclusive = progress["min_stream_ordering_exclusive"]
+        max_stream_ordering_inclusive = progress["max_stream_ordering_inclusive"]
+
+        def _populate_txn(txn: LoggingTransaction) -> bool:
+            """Returns True if we're done."""
+
+            # first we need to find an endpoint.
+            # we need to find the final row in the batch of batch_size, which means
+            # we need to skip over (batch_size-1) rows and get the next row.
+            txn.execute(
+                """
+                SELECT stream_ordering FROM events
+                WHERE stream_ordering > ? AND stream_ordering <= ?
+                ORDER BY stream_ordering
+                LIMIT 1 OFFSET ?
+                """,
+                (
+                    min_stream_ordering_exclusive,
+                    max_stream_ordering_inclusive,
+                    batch_size - 1,
+                ),
+            )
+
+            endpoint = None
+            row = txn.fetchone()
+            if row:
+                endpoint = row[0]
+
+            where_clause = "stream_ordering > ?"
+            args = [min_stream_ordering_exclusive]
+            if endpoint:
+                where_clause += " AND stream_ordering <= ?"
+                args.append(endpoint)
+
+            # now do the updates.
+            txn.execute(
+                f"""
+                UPDATE events
+                SET state_key = (SELECT state_key FROM state_events se WHERE se.event_id = events.event_id),
+                    rejection_reason = (SELECT reason FROM rejections rej WHERE rej.event_id = events.event_id)
+                WHERE ({where_clause})
+                """,
+                args,
+            )
+
+            logger.info(
+                "populated new `events` columns up to %s/%i: updated %i rows",
+                endpoint,
+                max_stream_ordering_inclusive,
+                txn.rowcount,
+            )
+
+            if endpoint is None:
+                # we're done
+                return True
+
+            progress["min_stream_ordering_exclusive"] = endpoint
+            self.db_pool.updates._background_update_progress_txn(
+                txn,
+                _BackgroundUpdates.EVENTS_POPULATE_STATE_KEY_REJECTIONS,
+                progress,
+            )
+            return False
+
+        done = await self.db_pool.runInteraction(
+            desc="events_populate_state_key_rejections", func=_populate_txn
+        )
+
+        if done:
+            await self.db_pool.updates._end_background_update(
+                _BackgroundUpdates.EVENTS_POPULATE_STATE_KEY_REJECTIONS
+            )
+
+        return batch_size
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index b99b107784..9b997c304d 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -54,6 +54,7 @@ from synapse.logging.context import (
     current_context,
     make_deferred_yieldable,
 )
+from synapse.logging.opentracing import start_active_span, tag_args, trace
 from synapse.metrics.background_process_metrics import (
     run_as_background_process,
     wrap_as_background_process,
@@ -79,7 +80,7 @@ from synapse.types import JsonDict, get_domain_from_id
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
 from synapse.util.caches.descriptors import cached, cachedList
-from synapse.util.caches.lrucache import LruCache
+from synapse.util.caches.lrucache import AsyncLruCache
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
@@ -238,7 +239,9 @@ class EventsWorkerStore(SQLBaseStore):
                 5 * 60 * 1000,
             )
 
-        self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache(
+        self._get_event_cache: AsyncLruCache[
+            Tuple[str], EventCacheEntry
+        ] = AsyncLruCache(
             cache_name="*getEvent*",
             max_size=hs.config.caches.event_cache_size,
         )
@@ -292,25 +295,6 @@ class EventsWorkerStore(SQLBaseStore):
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
 
-    async def get_received_ts(self, event_id: str) -> Optional[int]:
-        """Get received_ts (when it was persisted) for the event.
-
-        Raises an exception for unknown events.
-
-        Args:
-            event_id: The event ID to query.
-
-        Returns:
-            Timestamp in milliseconds, or None for events that were persisted
-            before received_ts was implemented.
-        """
-        return await self.db_pool.simple_select_one_onecol(
-            table="events",
-            keyvalues={"event_id": event_id},
-            retcol="received_ts",
-            desc="get_received_ts",
-        )
-
     async def have_censored_event(self, event_id: str) -> bool:
         """Check if an event has been censored, i.e. if the content of the event has been erased
         from the database due to a redaction.
@@ -447,6 +431,8 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {e.event_id: e for e in events}
 
+    @trace
+    @tag_args
     async def get_events_as_list(
         self,
         event_ids: Collection[str],
@@ -617,7 +603,11 @@ class EventsWorkerStore(SQLBaseStore):
         Returns:
             map from event id to result
         """
-        event_entry_map = self._get_events_from_cache(
+        # Shortcut: check if we have any events in the *in memory* cache - this function
+        # may be called repeatedly for the same event so at this point we cannot reach
+        # out to any external cache for performance reasons. The external cache is
+        # checked later on in the `get_missing_events_from_cache_or_db` function below.
+        event_entry_map = self._get_events_from_local_cache(
             event_ids,
         )
 
@@ -649,7 +639,9 @@ class EventsWorkerStore(SQLBaseStore):
 
         if missing_events_ids:
 
-            async def get_missing_events_from_db() -> Dict[str, EventCacheEntry]:
+            async def get_missing_events_from_cache_or_db() -> Dict[
+                str, EventCacheEntry
+            ]:
                 """Fetches the events in `missing_event_ids` from the database.
 
                 Also creates entries in `self._current_event_fetches` to allow
@@ -674,10 +666,18 @@ class EventsWorkerStore(SQLBaseStore):
                 # the events have been redacted, and if so pulling the redaction event
                 # out of the database to check it.
                 #
+                missing_events = {}
                 try:
-                    missing_events = await self._get_events_from_db(
+                    # Try to fetch from any external cache. We already checked the
+                    # in-memory cache above.
+                    missing_events = await self._get_events_from_external_cache(
                         missing_events_ids,
                     )
+                    # Now actually fetch any remaining events from the DB
+                    db_missing_events = await self._get_events_from_db(
+                        missing_events_ids - missing_events.keys(),
+                    )
+                    missing_events.update(db_missing_events)
                 except Exception as e:
                     with PreserveLoggingContext():
                         fetching_deferred.errback(e)
@@ -696,7 +696,7 @@ class EventsWorkerStore(SQLBaseStore):
             # cancellations, since multiple `_get_events_from_cache_or_db` calls can
             # reuse the same fetch.
             missing_events: Dict[str, EventCacheEntry] = await delay_cancellation(
-                get_missing_events_from_db()
+                get_missing_events_from_cache_or_db()
             )
             event_entry_map.update(missing_events)
 
@@ -729,15 +729,96 @@ class EventsWorkerStore(SQLBaseStore):
 
         return event_entry_map
 
-    def _invalidate_get_event_cache(self, event_id: str) -> None:
-        self._get_event_cache.invalidate((event_id,))
+    def invalidate_get_event_cache_after_txn(
+        self, txn: LoggingTransaction, event_id: str
+    ) -> None:
+        """
+        Prepares a database transaction to invalidate the get event cache for a given
+        event ID when executed successfully. This is achieved by attaching two callbacks
+        to the transaction, one to invalidate the async cache and one for the in memory
+        sync cache (importantly called in that order).
+
+        Arguments:
+            txn: the database transaction to attach the callbacks to
+            event_id: the event ID to be invalidated from caches
+        """
+
+        txn.async_call_after(self._invalidate_async_get_event_cache, event_id)
+        txn.call_after(self._invalidate_local_get_event_cache, event_id)
+
+    async def _invalidate_async_get_event_cache(self, event_id: str) -> None:
+        """
+        Invalidates an event in the asyncronous get event cache, which may be remote.
+
+        Arguments:
+            event_id: the event ID to invalidate
+        """
+
+        await self._get_event_cache.invalidate((event_id,))
+
+    def _invalidate_local_get_event_cache(self, event_id: str) -> None:
+        """
+        Invalidates an event in local in-memory get event caches.
+
+        Arguments:
+            event_id: the event ID to invalidate
+        """
+
+        self._get_event_cache.invalidate_local((event_id,))
         self._event_ref.pop(event_id, None)
         self._current_event_fetches.pop(event_id, None)
 
-    def _get_events_from_cache(
+    async def _get_events_from_cache(
+        self, events: Iterable[str], update_metrics: bool = True
+    ) -> Dict[str, EventCacheEntry]:
+        """Fetch events from the caches, both in memory and any external.
+
+        May return rejected events.
+
+        Args:
+            events: list of event_ids to fetch
+            update_metrics: Whether to update the cache hit ratio metrics
+        """
+        event_map = self._get_events_from_local_cache(
+            events, update_metrics=update_metrics
+        )
+
+        missing_event_ids = (e for e in events if e not in event_map)
+        event_map.update(
+            await self._get_events_from_external_cache(
+                events=missing_event_ids,
+                update_metrics=update_metrics,
+            )
+        )
+
+        return event_map
+
+    async def _get_events_from_external_cache(
+        self, events: Iterable[str], update_metrics: bool = True
+    ) -> Dict[str, EventCacheEntry]:
+        """Fetch events from any configured external cache.
+
+        May return rejected events.
+
+        Args:
+            events: list of event_ids to fetch
+            update_metrics: Whether to update the cache hit ratio metrics
+        """
+        event_map = {}
+
+        for event_id in events:
+            ret = await self._get_event_cache.get_external(
+                (event_id,), None, update_metrics=update_metrics
+            )
+            if ret:
+                event_map[event_id] = ret
+
+        return event_map
+
+    def _get_events_from_local_cache(
         self, events: Iterable[str], update_metrics: bool = True
     ) -> Dict[str, EventCacheEntry]:
-        """Fetch events from the caches.
+        """Fetch events from the local, in memory, caches.
 
         May return rejected events.
 
@@ -749,7 +830,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         for event_id in events:
             # First check if it's in the event cache
-            ret = self._get_event_cache.get(
+            ret = self._get_event_cache.get_local(
                 (event_id,), None, update_metrics=update_metrics
             )
             if ret:
@@ -771,7 +852,7 @@ class EventsWorkerStore(SQLBaseStore):
 
                 # We add the entry back into the cache as we want to keep
                 # recently queried events in the cache.
-                self._get_event_cache.set((event_id,), cache_entry)
+                self._get_event_cache.set_local((event_id,), cache_entry)
 
         return event_map
 
@@ -965,7 +1046,13 @@ class EventsWorkerStore(SQLBaseStore):
                 }
 
                 row_dict = self.db_pool.new_transaction(
-                    conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
+                    conn,
+                    "do_fetch",
+                    [],
+                    [],
+                    [],
+                    self._fetch_event_rows,
+                    events_to_fetch,
                 )
 
                 # We only want to resolve deferreds from the main thread
@@ -1006,23 +1093,42 @@ class EventsWorkerStore(SQLBaseStore):
         """
         fetched_event_ids: Set[str] = set()
         fetched_events: Dict[str, _EventRow] = {}
-        events_to_fetch = event_ids
 
-        while events_to_fetch:
-            row_map = await self._enqueue_events(events_to_fetch)
+        async def _fetch_event_ids_and_get_outstanding_redactions(
+            event_ids_to_fetch: Collection[str],
+        ) -> Collection[str]:
+            """
+            Fetch all of the given event_ids and return any associated redaction event_ids
+            that we still need to fetch in the next iteration.
+            """
+            row_map = await self._enqueue_events(event_ids_to_fetch)
 
             # we need to recursively fetch any redactions of those events
             redaction_ids: Set[str] = set()
-            for event_id in events_to_fetch:
+            for event_id in event_ids_to_fetch:
                 row = row_map.get(event_id)
                 fetched_event_ids.add(event_id)
                 if row:
                     fetched_events[event_id] = row
                     redaction_ids.update(row.redactions)
 
-            events_to_fetch = redaction_ids.difference(fetched_event_ids)
-            if events_to_fetch:
-                logger.debug("Also fetching redaction events %s", events_to_fetch)
+            event_ids_to_fetch = redaction_ids.difference(fetched_event_ids)
+            return event_ids_to_fetch
+
+        # Grab the initial list of events requested
+        event_ids_to_fetch = await _fetch_event_ids_and_get_outstanding_redactions(
+            event_ids
+        )
+        # Then go and recursively find all of the associated redactions
+        with start_active_span("recursively fetching redactions"):
+            while event_ids_to_fetch:
+                logger.debug("Also fetching redaction events %s", event_ids_to_fetch)
+
+                event_ids_to_fetch = (
+                    await _fetch_event_ids_and_get_outstanding_redactions(
+                        event_ids_to_fetch
+                    )
+                )
 
         # build a map from event_id to EventBase
         event_map: Dict[str, EventBase] = {}
@@ -1148,7 +1254,7 @@ class EventsWorkerStore(SQLBaseStore):
                 event=original_ev, redacted_event=redacted_event
             )
 
-            self._get_event_cache.set((event_id,), cache_entry)
+            await self._get_event_cache.set((event_id,), cache_entry)
             result_map[event_id] = cache_entry
 
             if not redacted_event:
@@ -1340,6 +1446,8 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {r["event_id"] for r in rows}
 
+    @trace
+    @tag_args
     async def have_seen_events(
         self, room_id: str, event_ids: Iterable[str]
     ) -> Set[str]:
@@ -1382,7 +1490,9 @@ class EventsWorkerStore(SQLBaseStore):
         # if the event cache contains the event, obviously we've seen it.
 
         cache_results = {
-            (rid, eid) for (rid, eid) in keys if self._get_event_cache.contains((eid,))
+            (rid, eid)
+            for (rid, eid) in keys
+            if await self._get_event_cache.contains((eid,))
         }
         results = dict.fromkeys(cache_results, True)
         remaining = [k for k in keys if k not in cache_results]
@@ -1465,7 +1575,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def get_all_new_forward_event_rows(
         self, instance_name: str, last_id: int, current_id: int, limit: int
-    ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
+    ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
         """Returns new events, for the Events replication stream
 
         Args:
@@ -1481,10 +1591,11 @@ class EventsWorkerStore(SQLBaseStore):
 
         def get_all_new_forward_event_rows(
             txn: LoggingTransaction,
-        ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
+        ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
             sql = (
                 "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
+                " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL,"
+                " e.outlier"
                 " FROM events AS e"
                 " LEFT JOIN redactions USING (event_id)"
                 " LEFT JOIN state_events AS se USING (event_id)"
@@ -1498,7 +1609,8 @@ class EventsWorkerStore(SQLBaseStore):
             )
             txn.execute(sql, (last_id, current_id, instance_name, limit))
             return cast(
-                List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+                List[Tuple[int, str, str, str, str, str, str, str, bool, bool]],
+                txn.fetchall(),
             )
 
         return await self.db_pool.runInteraction(
@@ -1507,7 +1619,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def get_ex_outlier_stream_rows(
         self, instance_name: str, last_id: int, current_id: int
-    ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
+    ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
         """Returns de-outliered events, for the Events replication stream
 
         Args:
@@ -1522,11 +1634,14 @@ class EventsWorkerStore(SQLBaseStore):
 
         def get_ex_outlier_stream_rows_txn(
             txn: LoggingTransaction,
-        ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
+        ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
             sql = (
                 "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
-                " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
+                " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL,"
+                " e.outlier"
                 " FROM events AS e"
+                # NB: the next line (inner join) is what makes this query different from
+                # get_all_new_forward_event_rows.
                 " INNER JOIN ex_outlier_stream AS out USING (event_id)"
                 " LEFT JOIN redactions USING (event_id)"
                 " LEFT JOIN state_events AS se USING (event_id)"
@@ -1541,7 +1656,8 @@ class EventsWorkerStore(SQLBaseStore):
 
             txn.execute(sql, (last_id, current_id, instance_name))
             return cast(
-                List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+                List[Tuple[int, str, str, str, str, str, str, str, bool, bool]],
+                txn.fetchall(),
             )
 
         return await self.db_pool.runInteraction(
@@ -1995,7 +2111,14 @@ class EventsWorkerStore(SQLBaseStore):
                 AND room_id = ?
                 /* Make sure event is not rejected */
                 AND rejections.event_id IS NULL
-            ORDER BY origin_server_ts %s
+            /**
+             * First sort by the message timestamp. If the message timestamps are the
+             * same, we want the message that logically comes "next" (before/after
+             * the given timestamp) based on the DAG and its topological order (`depth`).
+             * Finally, we can tie-break based on when it was received on the server
+             * (`stream_ordering`).
+             */
+            ORDER BY origin_server_ts %s, depth %s, stream_ordering %s
             LIMIT 1;
         """
 
@@ -2014,7 +2137,8 @@ class EventsWorkerStore(SQLBaseStore):
                 order = "ASC"
 
             txn.execute(
-                sql_template % (comparison_operator, order), (timestamp, room_id)
+                sql_template % (comparison_operator, order, order, order),
+                (timestamp, room_id),
             )
             row = txn.fetchone()
             if row:
@@ -2079,14 +2203,92 @@ class EventsWorkerStore(SQLBaseStore):
     def _get_partial_state_events_batch_txn(
         txn: LoggingTransaction, room_id: str
     ) -> List[str]:
+        # we want to work through the events from oldest to newest, so
+        # we only want events whose prev_events do *not* have partial state - hence
+        # the 'NOT EXISTS' clause in the below.
+        #
+        # This is necessary because ordering by stream ordering isn't quite enough
+        # to ensure that we work from oldest to newest event (in particular,
+        # if an event is initially persisted as an outlier and later de-outliered,
+        # it can end up with a lower stream_ordering than its prev_events).
+        #
+        # Typically this means we'll only return one event per batch, but that's
+        # hard to do much about.
+        #
+        # See also: https://github.com/matrix-org/synapse/issues/13001
         txn.execute(
             """
             SELECT event_id FROM partial_state_events AS pse
                 JOIN events USING (event_id)
-            WHERE pse.room_id = ?
+            WHERE pse.room_id = ? AND
+               NOT EXISTS(
+                  SELECT 1 FROM event_edges AS ee
+                     JOIN partial_state_events AS prev_pse ON (prev_pse.event_id=ee.prev_event_id)
+                     WHERE ee.event_id=pse.event_id
+               )
             ORDER BY events.stream_ordering
             LIMIT 100
             """,
             (room_id,),
         )
         return [row[0] for row in txn]
+
+    def mark_event_rejected_txn(
+        self,
+        txn: LoggingTransaction,
+        event_id: str,
+        rejection_reason: Optional[str],
+    ) -> None:
+        """Mark an event that was previously accepted as rejected, or vice versa
+
+        This can happen, for example, when resyncing state during a faster join.
+
+        Args:
+            txn:
+            event_id: ID of event to update
+            rejection_reason: reason it has been rejected, or None if it is now accepted
+        """
+        if rejection_reason is None:
+            logger.info(
+                "Marking previously-processed event %s as accepted",
+                event_id,
+            )
+            self.db_pool.simple_delete_txn(
+                txn,
+                "rejections",
+                keyvalues={"event_id": event_id},
+            )
+        else:
+            logger.info(
+                "Marking previously-processed event %s as rejected(%s)",
+                event_id,
+                rejection_reason,
+            )
+            self.db_pool.simple_upsert_txn(
+                txn,
+                table="rejections",
+                keyvalues={"event_id": event_id},
+                values={
+                    "reason": rejection_reason,
+                    "last_check": self._clock.time_msec(),
+                },
+            )
+        self.db_pool.simple_update_txn(
+            txn,
+            table="events",
+            keyvalues={"event_id": event_id},
+            updatevalues={"rejection_reason": rejection_reason},
+        )
+
+        self.invalidate_get_event_cache_after_txn(txn, event_id)
+
+        # TODO(faster_joins): invalidate the cache on workers. Ideally we'd just
+        #   call '_send_invalidation_to_replication', but we actually need the other
+        #   end to call _invalidate_local_get_event_cache() rather than (just)
+        #   _get_event_cache.invalidate().
+        #
+        #   One solution might be to (somehow) get the workers to call
+        #   _invalidate_caches_for_event() (though that will invalidate more than
+        #   strictly necessary).
+        #
+        #   https://github.com/matrix-org/synapse/issues/12994
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
deleted file mode 100644
index c15a7136b6..0000000000
--- a/synapse/storage/databases/main/group_server.py
+++ /dev/null
@@ -1,34 +0,0 @@
-# Copyright 2017 Vector Creations Ltd
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import TYPE_CHECKING
-
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
-
-if TYPE_CHECKING:
-    from synapse.server import HomeServer
-
-
-class GroupServerStore(SQLBaseStore):
-    def __init__(
-        self,
-        database: DatabasePool,
-        db_conn: LoggingDatabaseConnection,
-        hs: "HomeServer",
-    ):
-        # Register a legacy groups background update as a no-op.
-        database.updates.register_noop_background_update("local_group_updates_index")
-        super().__init__(database, db_conn, hs)
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index d028be16de..9b172a64d8 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -37,9 +37,6 @@ from synapse.types import JsonDict, UserID
 if TYPE_CHECKING:
     from synapse.server import HomeServer
 
-BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
-    "media_repository_drop_index_wo_method"
-)
 BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = (
     "media_repository_drop_index_wo_method_2"
 )
@@ -111,13 +108,6 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
             unique=True,
         )
 
-        # the original impl of _drop_media_index_without_method was broken (see
-        # https://github.com/matrix-org/synapse/issues/8649), so we replace the original
-        # impl with a no-op and run the fixed migration as
-        # media_repository_drop_index_wo_method_2.
-        self.db_pool.updates.register_noop_background_update(
-            BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD
-        )
         self.db_pool.updates.register_background_update_handler(
             BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2,
             self._drop_media_index_without_method,
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 9a63f953fb..efd136a864 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -66,6 +66,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
                 "initialise_mau_threepids",
                 [],
                 [],
+                [],
                 self._initialise_reserved_users,
                 hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
             )
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index ba385f9fc4..f6822707e4 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -19,6 +19,8 @@ from synapse.api.errors import SynapseError
 from synapse.storage.database import LoggingTransaction
 from synapse.storage.databases.main import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.state import StateGroupWorkerStore
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.engines._base import IsolationLevel
 from synapse.types import RoomStreamToken
 
 logger = logging.getLogger(__name__)
@@ -214,10 +216,10 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
 
         # Delete all remote non-state events
         for table in (
+            "event_edges",
             "events",
             "event_json",
             "event_auth",
-            "event_edges",
             "event_forward_extremities",
             "event_relations",
             "event_search",
@@ -302,7 +304,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
                 self._invalidate_cache_and_stream(
                     txn, self.have_seen_event, (room_id, event_id)
                 )
-                self._invalidate_get_event_cache(event_id)
+                self.invalidate_get_event_cache_after_txn(txn, event_id)
 
         logger.info("[purge] done")
 
@@ -317,11 +319,38 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         Returns:
             The list of state groups to delete.
         """
-        return await self.db_pool.runInteraction(
-            "purge_room", self._purge_room_txn, room_id
+
+        # This first runs the purge transaction with READ_COMMITTED isolation level,
+        # meaning any new rows in the tables will not trigger a serialization error.
+        # We then run the same purge a second time without this isolation level to
+        # purge any of those rows which were added during the first.
+
+        state_groups_to_delete = await self.db_pool.runInteraction(
+            "purge_room",
+            self._purge_room_txn,
+            room_id=room_id,
+            isolation_level=IsolationLevel.READ_COMMITTED,
+        )
+
+        state_groups_to_delete.extend(
+            await self.db_pool.runInteraction(
+                "purge_room",
+                self._purge_room_txn,
+                room_id=room_id,
+            ),
         )
 
+        return state_groups_to_delete
+
     def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]:
+        # This collides with event persistence so we cannot write new events and metadata into
+        # a room while deleting it or this transaction will fail.
+        if isinstance(self.database_engine, PostgresEngine):
+            txn.execute(
+                "SELECT room_version FROM rooms WHERE room_id = ? FOR UPDATE",
+                (room_id,),
+            )
+
         # First, fetch all the state groups that should be deleted, before
         # we delete that information.
         txn.execute(
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index d5aefe02b6..5079edd1e0 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -14,11 +14,23 @@
 # limitations under the License.
 import abc
 import logging
-from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple, Union, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Collection,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+    cast,
+)
 
 from synapse.api.errors import StoreError
 from synapse.config.homeserver import ExperimentalConfig
-from synapse.push.baserules import list_with_base_rules
+from synapse.push.baserules import FilteredPushRules, PushRule, compile_push_rules
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import (
@@ -50,69 +62,39 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-def _is_experimental_rule_enabled(
-    rule_id: str, experimental_config: ExperimentalConfig
-) -> bool:
-    """Used by `_load_rules` to filter out experimental rules when they
-    have not been enabled.
-    """
-    if (
-        rule_id == "global/override/.org.matrix.msc3786.rule.room.server_acl"
-        and not experimental_config.msc3786_enabled
-    ):
-        return False
-    if (
-        rule_id == "global/underride/.org.matrix.msc3772.thread_reply"
-        and not experimental_config.msc3772_enabled
-    ):
-        return False
-    return True
-
-
 def _load_rules(
     rawrules: List[JsonDict],
     enabled_map: Dict[str, bool],
     experimental_config: ExperimentalConfig,
-) -> List[JsonDict]:
-    ruleslist = []
-    for rawrule in rawrules:
-        rule = dict(rawrule)
-        rule["conditions"] = db_to_json(rawrule["conditions"])
-        rule["actions"] = db_to_json(rawrule["actions"])
-        rule["default"] = False
-        ruleslist.append(rule)
-
-    # We're going to be mutating this a lot, so copy it. We also filter out
-    # any experimental default push rules that aren't enabled.
-    rules = [
-        rule
-        for rule in list_with_base_rules(ruleslist)
-        if _is_experimental_rule_enabled(rule["rule_id"], experimental_config)
-    ]
+) -> FilteredPushRules:
+    """Take the DB rows returned from the DB and convert them into a full
+    `FilteredPushRules` object.
+    """
 
-    for i, rule in enumerate(rules):
-        rule_id = rule["rule_id"]
+    ruleslist = [
+        PushRule(
+            rule_id=rawrule["rule_id"],
+            priority_class=rawrule["priority_class"],
+            conditions=db_to_json(rawrule["conditions"]),
+            actions=db_to_json(rawrule["actions"]),
+        )
+        for rawrule in rawrules
+    ]
 
-        if rule_id not in enabled_map:
-            continue
-        if rule.get("enabled", True) == bool(enabled_map[rule_id]):
-            continue
+    push_rules = compile_push_rules(ruleslist)
 
-        # Rules are cached across users.
-        rule = dict(rule)
-        rule["enabled"] = bool(enabled_map[rule_id])
-        rules[i] = rule
+    filtered_rules = FilteredPushRules(push_rules, enabled_map, experimental_config)
 
-    return rules
+    return filtered_rules
 
 
 # The ABCMeta metaclass ensures that it cannot be instantiated without
 # the abstract methods being implemented.
 class PushRulesWorkerStore(
     ApplicationServiceWorkerStore,
-    ReceiptsWorkerStore,
     PusherWorkerStore,
     RoomMemberWorkerStore,
+    ReceiptsWorkerStore,
     EventsWorkerStore,
     SQLBaseStore,
     metaclass=abc.ABCMeta,
@@ -162,7 +144,7 @@ class PushRulesWorkerStore(
         raise NotImplementedError()
 
     @cached(max_entries=5000)
-    async def get_push_rules_for_user(self, user_id: str) -> List[JsonDict]:
+    async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
         rows = await self.db_pool.simple_select_list(
             table="push_rules",
             keyvalues={"user_name": user_id},
@@ -183,7 +165,6 @@ class PushRulesWorkerStore(
 
         return _load_rules(rows, enabled_map, self.hs.config.experimental)
 
-    @cached(max_entries=5000)
     async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
         results = await self.db_pool.simple_select_list(
             table="push_rules_enable",
@@ -216,11 +197,11 @@ class PushRulesWorkerStore(
     @cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids")
     async def bulk_get_push_rules(
         self, user_ids: Collection[str]
-    ) -> Dict[str, List[JsonDict]]:
+    ) -> Dict[str, FilteredPushRules]:
         if not user_ids:
             return {}
 
-        results: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
+        raw_rules: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
 
         rows = await self.db_pool.simple_select_many_batch(
             table="push_rules",
@@ -228,25 +209,25 @@ class PushRulesWorkerStore(
             iterable=user_ids,
             retcols=("*",),
             desc="bulk_get_push_rules",
+            batch_size=1000,
         )
 
         rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
 
         for row in rows:
-            results.setdefault(row["user_name"], []).append(row)
+            raw_rules.setdefault(row["user_name"], []).append(row)
 
         enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
 
-        for user_id, rules in results.items():
+        results: Dict[str, FilteredPushRules] = {}
+
+        for user_id, rules in raw_rules.items():
             results[user_id] = _load_rules(
                 rules, enabled_map_by_user.get(user_id, {}), self.hs.config.experimental
             )
 
         return results
 
-    @cachedList(
-        cached_method_name="get_push_rules_enabled_for_user", list_name="user_ids"
-    )
     async def bulk_get_push_rules_enabled(
         self, user_ids: Collection[str]
     ) -> Dict[str, Dict[str, bool]]:
@@ -261,6 +242,7 @@ class PushRulesWorkerStore(
             iterable=user_ids,
             retcols=("user_name", "rule_id", "enabled"),
             desc="bulk_get_push_rules_enabled",
+            batch_size=1000,
         )
         for row in rows:
             enabled = bool(row["enabled"])
@@ -344,8 +326,8 @@ class PushRuleStore(PushRulesWorkerStore):
         user_id: str,
         rule_id: str,
         priority_class: int,
-        conditions: List[Dict[str, str]],
-        actions: List[Union[JsonDict, str]],
+        conditions: Sequence[Mapping[str, str]],
+        actions: Sequence[Union[Mapping[str, Any], str]],
         before: Optional[str] = None,
         after: Optional[str] = None,
     ) -> None:
@@ -807,7 +789,6 @@ class PushRuleStore(PushRulesWorkerStore):
         self.db_pool.simple_insert_txn(txn, "push_rules_stream", values=values)
 
         txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,))
-        txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,))
         txn.call_after(
             self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
         )
@@ -816,7 +797,7 @@ class PushRuleStore(PushRulesWorkerStore):
         return self._push_rules_stream_id_gen.get_current_token()
 
     async def copy_push_rule_from_room_to_room(
-        self, new_room_id: str, user_id: str, rule: dict
+        self, new_room_id: str, user_id: str, rule: PushRule
     ) -> None:
         """Copy a single push rule from one room to another for a specific user.
 
@@ -826,21 +807,27 @@ class PushRuleStore(PushRulesWorkerStore):
             rule: A push rule.
         """
         # Create new rule id
-        rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
+        rule_id_scope = "/".join(rule.rule_id.split("/")[:-1])
         new_rule_id = rule_id_scope + "/" + new_room_id
 
+        new_conditions = []
+
         # Change room id in each condition
-        for condition in rule.get("conditions", []):
+        for condition in rule.conditions:
+            new_condition = condition
             if condition.get("key") == "room_id":
-                condition["pattern"] = new_room_id
+                new_condition = dict(condition)
+                new_condition["pattern"] = new_room_id
+
+            new_conditions.append(new_condition)
 
         # Add the rule for the new room
         await self.add_push_rule(
             user_id=user_id,
             rule_id=new_rule_id,
-            priority_class=rule["priority_class"],
-            conditions=rule["conditions"],
-            actions=rule["actions"],
+            priority_class=rule.priority_class,
+            conditions=new_conditions,
+            actions=rule.actions,
         )
 
     async def copy_push_rules_from_room_to_room_for_user(
@@ -858,8 +845,11 @@ class PushRuleStore(PushRulesWorkerStore):
         user_push_rules = await self.get_push_rules_for_user(user_id)
 
         # Get rules relating to the old room and copy them to the new room
-        for rule in user_push_rules:
-            conditions = rule.get("conditions", [])
+        for rule, enabled in user_push_rules:
+            if not enabled:
+                continue
+
+            conditions = rule.conditions
             if any(
                 (c.get("key") == "room_id" and c.get("pattern") == old_room_id)
                 for c in conditions
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 21e954ccc1..124c70ad37 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -26,7 +26,7 @@ from typing import (
     cast,
 )
 
-from synapse.api.constants import EduTypes, ReceiptTypes
+from synapse.api.constants import EduTypes
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import ReceiptsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
@@ -36,6 +36,7 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.engines import PostgresEngine
+from synapse.storage.engines._base import IsolationLevel
 from synapse.storage.util.id_generators import (
     AbstractStreamIdTracker,
     MultiWriterIdGenerator,
@@ -117,7 +118,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         return self._receipts_id_gen.get_current_token()
 
     async def get_last_receipt_event_id_for_user(
-        self, user_id: str, room_id: str, receipt_types: Iterable[str]
+        self, user_id: str, room_id: str, receipt_types: Collection[str]
     ) -> Optional[str]:
         """
         Fetch the event ID for the latest receipt in a room with one of the given receipt types.
@@ -125,58 +126,63 @@ class ReceiptsWorkerStore(SQLBaseStore):
         Args:
             user_id: The user to fetch receipts for.
             room_id: The room ID to fetch the receipt for.
-            receipt_type: The receipt types to fetch. Earlier receipt types
-                are given priority if multiple receipts point to the same event.
+            receipt_type: The receipt types to fetch.
 
         Returns:
             The latest receipt, if one exists.
         """
-        latest_event_id: Optional[str] = None
-        latest_stream_ordering = 0
-        for receipt_type in receipt_types:
-            result = await self._get_last_receipt_event_id_for_user(
-                user_id, room_id, receipt_type
-            )
-            if result is None:
-                continue
-            event_id, stream_ordering = result
-
-            if latest_event_id is None or latest_stream_ordering < stream_ordering:
-                latest_event_id = event_id
-                latest_stream_ordering = stream_ordering
+        result = await self.db_pool.runInteraction(
+            "get_last_receipt_event_id_for_user",
+            self.get_last_receipt_for_user_txn,
+            user_id,
+            room_id,
+            receipt_types,
+        )
+        if not result:
+            return None
 
-        return latest_event_id
+        event_id, _ = result
+        return event_id
 
-    @cached()
-    async def _get_last_receipt_event_id_for_user(
-        self, user_id: str, room_id: str, receipt_type: str
+    def get_last_receipt_for_user_txn(
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        room_id: str,
+        receipt_types: Collection[str],
     ) -> Optional[Tuple[str, int]]:
         """
-        Fetch the event ID and stream ordering for the latest receipt.
+        Fetch the event ID and stream_ordering for the latest receipt in a room
+        with one of the given receipt types.
 
         Args:
             user_id: The user to fetch receipts for.
             room_id: The room ID to fetch the receipt for.
-            receipt_type: The receipt type to fetch.
+            receipt_type: The receipt types to fetch.
 
         Returns:
-            The event ID and stream ordering of the latest receipt, if one exists;
-            otherwise `None`.
+            The event ID and stream ordering of the latest receipt, if one exists.
         """
-        sql = """
+
+        clause, args = make_in_list_sql_clause(
+            self.database_engine, "receipt_type", receipt_types
+        )
+
+        sql = f"""
             SELECT event_id, stream_ordering
             FROM receipts_linearized
             INNER JOIN events USING (room_id, event_id)
-            WHERE user_id = ?
+            WHERE {clause}
+            AND user_id = ?
             AND room_id = ?
-            AND receipt_type = ?
+            ORDER BY stream_ordering DESC
+            LIMIT 1
         """
 
-        def f(txn: LoggingTransaction) -> Optional[Tuple[str, int]]:
-            txn.execute(sql, (user_id, room_id, receipt_type))
-            return cast(Optional[Tuple[str, int]], txn.fetchone())
+        args.extend((user_id, room_id))
+        txn.execute(sql, args)
 
-        return await self.db_pool.runInteraction("get_own_receipt_for_user", f)
+        return cast(Optional[Tuple[str, int]], txn.fetchone())
 
     async def get_receipts_for_user(
         self, user_id: str, receipt_types: Iterable[str]
@@ -576,8 +582,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
     ) -> None:
         self._get_receipts_for_user_with_orderings.invalidate((user_id, receipt_type))
         self._get_linearized_receipts_for_room.invalidate((room_id,))
-        self._get_last_receipt_event_id_for_user.invalidate(
-            (user_id, room_id, receipt_type)
+
+        # We use this method to invalidate so that we don't end up with circular
+        # dependencies between the receipts and push action stores.
+        self._attempt_to_invalidate_cache(
+            "get_unread_event_push_actions_by_room_for_user", (room_id,)
         )
 
     def process_replication_rows(
@@ -673,17 +682,6 @@ class ReceiptsWorkerStore(SQLBaseStore):
             lock=False,
         )
 
-        # When updating a local users read receipt, remove any push actions
-        # which resulted from the receipt's event and all earlier events.
-        if (
-            self.hs.is_mine_id(user_id)
-            and receipt_type in (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE)
-            and stream_ordering is not None
-        ):
-            self._remove_old_push_actions_before_txn(  # type: ignore[attr-defined]
-                txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
-            )
-
         return rx_ts
 
     def _graph_to_linear(
@@ -764,6 +762,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 linearized_event_id,
                 data,
                 stream_id=stream_id,
+                # Read committed is actually beneficial here because we check for a receipt with
+                # greater stream order, and checking the very latest data at select time is better
+                # than the data at transaction start time.
+                isolation_level=IsolationLevel.READ_COMMITTED,
             )
 
         # If the receipt was older than the currently persisted one, nothing to do.
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 4991360b70..7fb9c801da 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -69,9 +69,9 @@ class TokenLookupResult:
     """
 
     user_id: str
+    token_id: int
     is_guest: bool = False
     shadow_banned: bool = False
-    token_id: Optional[int] = None
     device_id: Optional[str] = None
     valid_until_ms: Optional[int] = None
     token_owner: str = attr.ib()
@@ -1805,21 +1805,10 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
             columns=["creation_ts"],
         )
 
-        # we no longer use refresh tokens, but it's possible that some people
-        # might have a background update queued to build this index. Just
-        # clear the background update.
-        self.db_pool.updates.register_noop_background_update(
-            "refresh_tokens_device_index"
-        )
-
         self.db_pool.updates.register_background_update_handler(
             "users_set_deactivated_flag", self._background_update_set_deactivated_flag
         )
 
-        self.db_pool.updates.register_noop_background_update(
-            "user_threepids_grandfather"
-        )
-
         self.db_pool.updates.register_background_index_update(
             "user_external_ids_user_id_idx",
             index_name="user_external_ids_user_id_idx",
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index b457bc189e..7bd27790eb 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -62,7 +62,6 @@ class RelationsWorkerStore(SQLBaseStore):
         room_id: str,
         relation_type: Optional[str] = None,
         event_type: Optional[str] = None,
-        aggregation_key: Optional[str] = None,
         limit: int = 5,
         direction: str = "b",
         from_token: Optional[StreamToken] = None,
@@ -76,7 +75,6 @@ class RelationsWorkerStore(SQLBaseStore):
             room_id: The room the event belongs to.
             relation_type: Only fetch events with this relation type, if given.
             event_type: Only fetch events with this event type, if given.
-            aggregation_key: Only fetch events with this aggregation key, if given.
             limit: Only fetch the most recent `limit` events.
             direction: Whether to fetch the most recent first (`"b"`) or the
                 oldest first (`"f"`).
@@ -105,10 +103,6 @@ class RelationsWorkerStore(SQLBaseStore):
             where_clause.append("type = ?")
             where_args.append(event_type)
 
-        if aggregation_key:
-            where_clause.append("aggregation_key = ?")
-            where_args.append(aggregation_key)
-
         pagination_clause = generate_pagination_where_clause(
             direction=direction,
             column_names=("topological_ordering", "stream_ordering"),
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 68d4fc2e64..bef66f1992 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -32,12 +32,17 @@ from typing import (
 
 import attr
 
-from synapse.api.constants import EventContentFields, EventTypes, JoinRules
+from synapse.api.constants import (
+    EventContentFields,
+    EventTypes,
+    JoinRules,
+    PublicRoomsFilterFields,
+)
 from synapse.api.errors import StoreError
 from synapse.api.room_versions import RoomVersion, RoomVersions
 from synapse.config.homeserver import HomeServerConfig
 from synapse.events import EventBase
-from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
     LoggingDatabaseConnection,
@@ -170,7 +175,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
                   rooms.creator, state.encryption, state.is_federatable AS federatable,
                   rooms.is_public AS public, state.join_rules, state.guest_access,
                   state.history_visibility, curr.current_state_events AS state_events,
-                  state.avatar, state.topic
+                  state.avatar, state.topic, state.room_type
                 FROM rooms
                 LEFT JOIN room_stats_state state USING (room_id)
                 LEFT JOIN room_stats_current curr USING (room_id)
@@ -199,10 +204,29 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             desc="get_public_room_ids",
         )
 
+    def _construct_room_type_where_clause(
+        self, room_types: Union[List[Union[str, None]], None]
+    ) -> Tuple[Union[str, None], List[str]]:
+        if not room_types:
+            return None, []
+        else:
+            # We use None when we want get rooms without a type
+            is_null_clause = ""
+            if None in room_types:
+                is_null_clause = "OR room_type IS NULL"
+                room_types = [value for value in room_types if value is not None]
+
+            list_clause, args = make_in_list_sql_clause(
+                self.database_engine, "room_type", room_types
+            )
+
+            return f"({list_clause} {is_null_clause})", args
+
     async def count_public_rooms(
         self,
         network_tuple: Optional[ThirdPartyInstanceID],
         ignore_non_federatable: bool,
+        search_filter: Optional[dict],
     ) -> int:
         """Counts the number of public rooms as tracked in the room_stats_current
         and room_stats_state table.
@@ -210,11 +234,20 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         Args:
             network_tuple
             ignore_non_federatable: If true filters out non-federatable rooms
+            search_filter
         """
 
         def _count_public_rooms_txn(txn: LoggingTransaction) -> int:
             query_args = []
 
+            room_type_clause, args = self._construct_room_type_where_clause(
+                search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None)
+                if search_filter
+                else None
+            )
+            room_type_clause = f" AND {room_type_clause}" if room_type_clause else ""
+            query_args += args
+
             if network_tuple:
                 if network_tuple.appservice_id:
                     published_sql = """
@@ -249,6 +282,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
                         OR join_rules = '{JoinRules.KNOCK_RESTRICTED}'
                         OR history_visibility = 'world_readable'
                     )
+                    {room_type_clause}
                     AND joined_members > 0
             """
 
@@ -347,8 +381,12 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         if ignore_non_federatable:
             where_clauses.append("is_federatable")
 
-        if search_filter and search_filter.get("generic_search_term", None):
-            search_term = "%" + search_filter["generic_search_term"] + "%"
+        if search_filter and search_filter.get(
+            PublicRoomsFilterFields.GENERIC_SEARCH_TERM, None
+        ):
+            search_term = (
+                "%" + search_filter[PublicRoomsFilterFields.GENERIC_SEARCH_TERM] + "%"
+            )
 
             where_clauses.append(
                 """
@@ -365,6 +403,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
                 search_term.lower(),
             ]
 
+        room_type_clause, args = self._construct_room_type_where_clause(
+            search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None)
+            if search_filter
+            else None
+        )
+        if room_type_clause:
+            where_clauses.append(room_type_clause)
+        query_args += args
+
         where_clause = ""
         if where_clauses:
             where_clause = " AND " + " AND ".join(where_clauses)
@@ -373,7 +420,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         sql = f"""
             SELECT
                 room_id, name, topic, canonical_alias, joined_members,
-                avatar, history_visibility, guest_access, join_rules
+                avatar, history_visibility, guest_access, join_rules, room_type
             FROM (
                 {published_sql}
             ) published
@@ -549,7 +596,8 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             SELECT state.room_id, state.name, state.canonical_alias, curr.joined_members,
               curr.local_users_in_room, rooms.room_version, rooms.creator,
               state.encryption, state.is_federatable, rooms.is_public, state.join_rules,
-              state.guest_access, state.history_visibility, curr.current_state_events
+              state.guest_access, state.history_visibility, curr.current_state_events,
+              state.room_type
             FROM room_stats_state state
             INNER JOIN room_stats_current curr USING (room_id)
             INNER JOIN rooms USING (room_id)
@@ -593,12 +641,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
                         "version": room[5],
                         "creator": room[6],
                         "encryption": room[7],
-                        "federatable": room[8],
-                        "public": room[9],
+                        # room_stats_state.federatable is an integer on sqlite.
+                        "federatable": bool(room[8]),
+                        # rooms.is_public is an integer on sqlite.
+                        "public": bool(room[9]),
                         "join_rules": room[10],
                         "guest_access": room[11],
                         "history_visibility": room[12],
                         "state_events": room[13],
+                        "room_type": room[14],
                     }
                 )
 
@@ -1109,25 +1160,34 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         return room_servers
 
     async def clear_partial_state_room(self, room_id: str) -> bool:
-        # this can race with incoming events, so we watch out for FK errors.
-        # TODO(faster_joins): this still doesn't completely fix the race, since the persist process
-        #   is not atomic. I fear we need an application-level lock.
+        """Clears the partial state flag for a room.
+
+        Args:
+            room_id: The room whose partial state flag is to be cleared.
+
+        Returns:
+            `True` if the partial state flag has been cleared successfully.
+
+            `False` if the partial state flag could not be cleared because the room
+            still contains events with partial state.
+        """
         try:
             await self.db_pool.runInteraction(
                 "clear_partial_state_room", self._clear_partial_state_room_txn, room_id
             )
             return True
-        except self.db_pool.engine.module.DatabaseError as e:
-            # TODO(faster_joins): how do we distinguish between FK errors and other errors?
-            logger.warning(
+        except self.db_pool.engine.module.IntegrityError as e:
+            # Assume that any `IntegrityError`s are due to partial state events.
+            logger.info(
                 "Exception while clearing lazy partial-state-room %s, retrying: %s",
                 room_id,
                 e,
             )
             return False
 
-    @staticmethod
-    def _clear_partial_state_room_txn(txn: LoggingTransaction, room_id: str) -> None:
+    def _clear_partial_state_room_txn(
+        self, txn: LoggingTransaction, room_id: str
+    ) -> None:
         DatabasePool.simple_delete_txn(
             txn,
             table="partial_state_rooms_servers",
@@ -1138,7 +1198,9 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             table="partial_state_rooms",
             keyvalues={"room_id": room_id},
         )
+        self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
 
+    @cached()
     async def is_partial_state_room(self, room_id: str) -> bool:
         """Checks if this room has partial state.
 
@@ -1164,6 +1226,7 @@ class _BackgroundUpdates:
     POPULATE_ROOM_DEPTH_MIN_DEPTH2 = "populate_room_depth_min_depth2"
     REPLACE_ROOM_DEPTH_MIN_DEPTH = "replace_room_depth_min_depth"
     POPULATE_ROOMS_CREATOR_COLUMN = "populate_rooms_creator_column"
+    ADD_ROOM_TYPE_COLUMN = "add_room_type_column"
 
 
 _REPLACE_ROOM_DEPTH_SQL_COMMANDS = (
@@ -1198,6 +1261,11 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
             self._background_add_rooms_room_version_column,
         )
 
+        self.db_pool.updates.register_background_update_handler(
+            _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN,
+            self._background_add_room_type_column,
+        )
+
         # BG updates to change the type of room_depth.min_depth
         self.db_pool.updates.register_background_update_handler(
             _BackgroundUpdates.POPULATE_ROOM_DEPTH_MIN_DEPTH2,
@@ -1567,6 +1635,69 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
 
         return batch_size
 
+    async def _background_add_room_type_column(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        """Background update to go and add room_type information to `room_stats_state`
+        table from `event_json` table.
+        """
+
+        last_room_id = progress.get("room_id", "")
+
+        def _background_add_room_type_column_txn(
+            txn: LoggingTransaction,
+        ) -> bool:
+            sql = """
+                SELECT state.room_id, json FROM event_json
+                INNER JOIN current_state_events AS state USING (event_id)
+                WHERE state.room_id > ? AND type = 'm.room.create'
+                ORDER BY state.room_id
+                LIMIT ?
+            """
+
+            txn.execute(sql, (last_room_id, batch_size))
+            room_id_to_create_event_results = txn.fetchall()
+
+            new_last_room_id = None
+            for room_id, event_json in room_id_to_create_event_results:
+                event_dict = db_to_json(event_json)
+
+                room_type = event_dict.get("content", {}).get(
+                    EventContentFields.ROOM_TYPE, None
+                )
+                if isinstance(room_type, str):
+                    self.db_pool.simple_update_txn(
+                        txn,
+                        table="room_stats_state",
+                        keyvalues={"room_id": room_id},
+                        updatevalues={"room_type": room_type},
+                    )
+
+                new_last_room_id = room_id
+
+            if new_last_room_id is None:
+                return True
+
+            self.db_pool.updates._background_update_progress_txn(
+                txn,
+                _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN,
+                {"room_id": new_last_room_id},
+            )
+
+            return False
+
+        end = await self.db_pool.runInteraction(
+            "_background_add_room_type_column",
+            _background_add_room_type_column_txn,
+        )
+
+        if end:
+            await self.db_pool.updates._end_background_update(
+                _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN
+            )
+
+        return batch_size
+
 
 class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
     def __init__(
@@ -1643,9 +1774,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
             servers,
         )
 
-    @staticmethod
     def _store_partial_state_room_txn(
-        txn: LoggingTransaction, room_id: str, servers: Collection[str]
+        self, txn: LoggingTransaction, room_id: str, servers: Collection[str]
     ) -> None:
         DatabasePool.simple_insert_txn(
             txn,
@@ -1660,6 +1790,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
             keys=("room_id", "server_name"),
             values=((room_id, s) for s in servers),
         )
+        self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
 
     async def maybe_store_room_on_outlier_membership(
         self, room_id: str, room_version: RoomVersion
@@ -1875,9 +2006,15 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
 
             where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
 
+            # We join on room_stats_state despite not using any columns from it
+            # because the join can influence the number of rows returned;
+            # e.g. a room that doesn't have state, maybe because it was deleted.
+            # The query returning the total count should be consistent with
+            # the query returning the results.
             sql = """
                 SELECT COUNT(*) as total_event_reports
                 FROM event_reports AS er
+                JOIN room_stats_state ON room_stats_state.room_id = er.room_id
                 {}
                 """.format(
                 where_clause
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 31bc8c5601..4f0adb136a 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -21,6 +21,7 @@ from typing import (
     FrozenSet,
     Iterable,
     List,
+    Mapping,
     Optional,
     Set,
     Tuple,
@@ -30,8 +31,6 @@ from typing import (
 import attr
 
 from synapse.api.constants import EventTypes, Membership
-from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import (
     run_as_background_process,
@@ -56,6 +55,7 @@ from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
+from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
 if TYPE_CHECKING:
@@ -184,34 +184,109 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 self._check_safe_current_state_events_membership_updated_txn,
             )
 
-    @cached(max_entries=100000, iterable=True, prune_unread_entries=False)
+    @cached(max_entries=100000, iterable=True)
     async def get_users_in_room(self, room_id: str) -> List[str]:
+        """
+        Returns a list of users in the room sorted by longest in the room first
+        (aka. with the lowest depth). This is done to match the sort in
+        `get_current_hosts_in_room()` and so we can re-use the cache but it's
+        not horrible to have here either.
+        """
+
         return await self.db_pool.runInteraction(
             "get_users_in_room", self.get_users_in_room_txn, room_id
         )
 
     def get_users_in_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[str]:
+        """
+        Returns a list of users in the room sorted by longest in the room first
+        (aka. with the lowest depth). This is done to match the sort in
+        `get_current_hosts_in_room()` and so we can re-use the cache but it's
+        not horrible to have here either.
+        """
         # If we can assume current_state_events.membership is up to date
         # then we can avoid a join, which is a Very Good Thing given how
         # frequently this function gets called.
         if self._current_state_events_membership_up_to_date:
             sql = """
-                SELECT state_key FROM current_state_events
-                WHERE type = 'm.room.member' AND room_id = ? AND membership = ?
+                SELECT c.state_key FROM current_state_events as c
+                /* Get the depth of the event from the events table */
+                INNER JOIN events AS e USING (event_id)
+                WHERE c.type = 'm.room.member' AND c.room_id = ? AND membership = ?
+                /* Sorted by lowest depth first */
+                ORDER BY e.depth ASC;
             """
         else:
             sql = """
-                SELECT state_key FROM room_memberships as m
+                SELECT c.state_key FROM room_memberships as m
+                /* Get the depth of the event from the events table */
+                INNER JOIN events AS e USING (event_id)
                 INNER JOIN current_state_events as c
                 ON m.event_id = c.event_id
                 AND m.room_id = c.room_id
                 AND m.user_id = c.state_key
                 WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?
+                /* Sorted by lowest depth first */
+                ORDER BY e.depth ASC;
             """
 
         txn.execute(sql, (room_id, Membership.JOIN))
         return [r[0] for r in txn]
 
+    @cached()
+    def get_user_in_room_with_profile(
+        self, room_id: str, user_id: str
+    ) -> Dict[str, ProfileInfo]:
+        raise NotImplementedError()
+
+    @cachedList(
+        cached_method_name="get_user_in_room_with_profile", list_name="user_ids"
+    )
+    async def get_subset_users_in_room_with_profiles(
+        self, room_id: str, user_ids: Collection[str]
+    ) -> Dict[str, ProfileInfo]:
+        """Get a mapping from user ID to profile information for a list of users
+        in a given room.
+
+        The profile information comes directly from this room's `m.room.member`
+        events, and so may be specific to this room rather than part of a user's
+        global profile. To avoid privacy leaks, the profile data should only be
+        revealed to users who are already in this room.
+
+        Args:
+            room_id: The ID of the room to retrieve the users of.
+            user_ids: a list of users in the room to run the query for
+
+        Returns:
+                A mapping from user ID to ProfileInfo.
+        """
+
+        def _get_subset_users_in_room_with_profiles(
+            txn: LoggingTransaction,
+        ) -> Dict[str, ProfileInfo]:
+            clause, ids = make_in_list_sql_clause(
+                self.database_engine, "c.state_key", user_ids
+            )
+
+            sql = """
+                SELECT state_key, display_name, avatar_url FROM room_memberships as m
+                INNER JOIN current_state_events as c
+                ON m.event_id = c.event_id
+                AND m.room_id = c.room_id
+                AND m.user_id = c.state_key
+                WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ? AND %s
+            """ % (
+                clause,
+            )
+            txn.execute(sql, (room_id, Membership.JOIN, *ids))
+
+            return {r[0]: ProfileInfo(display_name=r[1], avatar_url=r[2]) for r in txn}
+
+        return await self.db_pool.runInteraction(
+            "get_subset_users_in_room_with_profiles",
+            _get_subset_users_in_room_with_profiles,
+        )
+
     @cached(max_entries=100000, iterable=True)
     async def get_users_in_room_with_profiles(
         self, room_id: str
@@ -228,6 +303,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         Returns:
             A mapping from user ID to ProfileInfo.
+
+        Preconditions:
+          - There is full state available for the room (it is not partial-stated).
         """
 
         def _get_users_in_room_with_profiles(
@@ -338,6 +416,15 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
 
     @cached()
+    async def get_number_joined_users_in_room(self, room_id: str) -> int:
+        return await self.db_pool.simple_select_one_onecol(
+            table="current_state_events",
+            keyvalues={"room_id": room_id, "membership": Membership.JOIN},
+            retcol="COUNT(*)",
+            desc="get_number_joined_users_in_room",
+        )
+
+    @cached()
     async def get_invited_rooms_for_local_user(
         self, user_id: str
     ) -> List[RoomsForUser]:
@@ -416,6 +503,17 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         user_id: str,
         membership_list: List[str],
     ) -> List[RoomsForUser]:
+        """Get all the rooms for this *local* user where the membership for this user
+        matches one in the membership list.
+
+        Args:
+            user_id: The user ID.
+            membership_list: A list of synapse.api.constants.Membership
+                    values which the user must be in.
+
+        Returns:
+            The RoomsForUser that the user matches the membership types.
+        """
         # Paranoia check.
         if not self.hs.is_mine_id(user_id):
             raise Exception(
@@ -444,6 +542,44 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return results
 
+    @cached(iterable=True)
+    async def get_local_users_in_room(self, room_id: str) -> List[str]:
+        """
+        Retrieves a list of the current roommembers who are local to the server.
+        """
+        return await self.db_pool.simple_select_onecol(
+            table="local_current_membership",
+            keyvalues={"room_id": room_id, "membership": Membership.JOIN},
+            retcol="user_id",
+            desc="get_local_users_in_room",
+        )
+
+    async def check_local_user_in_room(self, user_id: str, room_id: str) -> bool:
+        """
+        Check whether a given local user is currently joined to the given room.
+
+        Returns:
+            A boolean indicating whether the user is currently joined to the room
+
+        Raises:
+            Exeption when called with a non-local user to this homeserver
+        """
+        if not self.hs.is_mine_id(user_id):
+            raise Exception(
+                "Cannot call 'check_local_user_in_room' on "
+                "non-local user %s" % (user_id,),
+            )
+
+        (
+            membership,
+            member_event_id,
+        ) = await self.get_local_current_membership_for_user_in_room(
+            user_id=user_id,
+            room_id=room_id,
+        )
+
+        return membership == Membership.JOIN
+
     async def get_local_current_membership_for_user_in_room(
         self, user_id: str, room_id: str
     ) -> Tuple[Optional[str], Optional[str]]:
@@ -476,7 +612,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return results_dict.get("membership"), results_dict.get("event_id")
 
-    @cached(max_entries=500000, iterable=True, prune_unread_entries=False)
+    @cached(max_entries=500000, iterable=True)
     async def get_rooms_for_user_with_stream_ordering(
         self, user_id: str
     ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
@@ -647,25 +783,76 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
         return frozenset(r.room_id for r in rooms)
 
-    @cached(
-        max_entries=500000,
-        cache_context=True,
-        iterable=True,
-        prune_unread_entries=False,
+    @cached(max_entries=10000)
+    async def does_pair_of_users_share_a_room(
+        self, user_id: str, other_user_id: str
+    ) -> bool:
+        raise NotImplementedError()
+
+    @cachedList(
+        cached_method_name="does_pair_of_users_share_a_room", list_name="other_user_ids"
     )
-    async def get_users_who_share_room_with_user(
-        self, user_id: str, cache_context: _CacheContext
+    async def _do_users_share_a_room(
+        self, user_id: str, other_user_ids: Collection[str]
+    ) -> Mapping[str, Optional[bool]]:
+        """Return mapping from user ID to whether they share a room with the
+        given user.
+
+        Note: `None` and `False` are equivalent and mean they don't share a
+        room.
+        """
+
+        def do_users_share_a_room_txn(
+            txn: LoggingTransaction, user_ids: Collection[str]
+        ) -> Dict[str, bool]:
+            clause, args = make_in_list_sql_clause(
+                self.database_engine, "state_key", user_ids
+            )
+
+            # This query works by fetching both the list of rooms for the target
+            # user and the set of other users, and then checking if there is any
+            # overlap.
+            sql = f"""
+                SELECT b.state_key
+                FROM (
+                    SELECT room_id FROM current_state_events
+                    WHERE type = 'm.room.member' AND membership = 'join' AND state_key = ?
+                ) AS a
+                INNER JOIN (
+                    SELECT room_id, state_key FROM current_state_events
+                    WHERE type = 'm.room.member' AND membership = 'join' AND {clause}
+                ) AS b using (room_id)
+                LIMIT 1
+            """
+
+            txn.execute(sql, (user_id, *args))
+            return {u: True for u, in txn}
+
+        to_return = {}
+        for batch_user_ids in batch_iter(other_user_ids, 1000):
+            res = await self.db_pool.runInteraction(
+                "do_users_share_a_room", do_users_share_a_room_txn, batch_user_ids
+            )
+            to_return.update(res)
+
+        return to_return
+
+    async def do_users_share_a_room(
+        self, user_id: str, other_user_ids: Collection[str]
     ) -> Set[str]:
+        """Return the set of users who share a room with the first users"""
+
+        user_dict = await self._do_users_share_a_room(user_id, other_user_ids)
+
+        return {u for u, share_room in user_dict.items() if share_room}
+
+    async def get_users_who_share_room_with_user(self, user_id: str) -> Set[str]:
         """Returns the set of users who share a room with `user_id`"""
-        room_ids = await self.get_rooms_for_user(
-            user_id, on_invalidate=cache_context.invalidate
-        )
+        room_ids = await self.get_rooms_for_user(user_id)
 
         user_who_share_room = set()
         for room_id in room_ids:
-            user_ids = await self.get_users_in_room(
-                room_id, on_invalidate=cache_context.invalidate
-            )
+            user_ids = await self.get_users_in_room(room_id)
             user_who_share_room.update(user_ids)
 
         return user_who_share_room
@@ -694,161 +881,92 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return shared_room_ids or frozenset()
 
-    async def get_joined_users_from_context(
-        self, event: EventBase, context: EventContext
-    ) -> Dict[str, ProfileInfo]:
-        state_group: Union[object, int] = context.state_group
-        if not state_group:
-            # If state_group is None it means it has yet to be assigned a
-            # state group, i.e. we need to make sure that calls with a state_group
-            # of None don't hit previous cached calls with a None state_group.
-            # To do this we set the state_group to a new object as object() != object()
-            state_group = object()
+    async def get_joined_user_ids_from_state(
+        self, room_id: str, state: StateMap[str]
+    ) -> Set[str]:
+        """
+        For a given set of state IDs, get a set of user IDs in the room.
 
-        current_state_ids = await context.get_current_state_ids()
-        assert current_state_ids is not None
-        assert state_group is not None
-        return await self._get_joined_users_from_context(
-            event.room_id, state_group, current_state_ids, event=event, context=context
-        )
+        This method checks the local event cache, before calling
+        `_get_user_ids_from_membership_event_ids` for any uncached events.
+        """
 
-    async def get_joined_users_from_state(
-        self, room_id: str, state_entry: "_StateCacheEntry"
-    ) -> Dict[str, ProfileInfo]:
-        state_group: Union[object, int] = state_entry.state_group
-        if not state_group:
-            # If state_group is None it means it has yet to be assigned a
-            # state group, i.e. we need to make sure that calls with a state_group
-            # of None don't hit previous cached calls with a None state_group.
-            # To do this we set the state_group to a new object as object() != object()
-            state_group = object()
+        with Measure(self._clock, "get_joined_user_ids_from_state"):
+            users_in_room = set()
+            member_event_ids = [
+                e_id for key, e_id in state.items() if key[0] == EventTypes.Member
+            ]
 
-        assert state_group is not None
-        with Measure(self._clock, "get_joined_users_from_state"):
-            return await self._get_joined_users_from_context(
-                room_id, state_group, state_entry.state, context=state_entry
-            )
+            # We check if we have any of the member event ids in the event cache
+            # before we ask the DB
 
-    @cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
-    async def _get_joined_users_from_context(
-        self,
-        room_id: str,
-        state_group: Union[object, int],
-        current_state_ids: StateMap[str],
-        cache_context: _CacheContext,
-        event: Optional[EventBase] = None,
-        context: Optional[Union[EventContext, "_StateCacheEntry"]] = None,
-    ) -> Dict[str, ProfileInfo]:
-        # We don't use `state_group`, it's there so that we can cache based
-        # on it. However, it's important that it's never None, since two current_states
-        # with a state_group of None are likely to be different.
-        assert state_group is not None
-
-        users_in_room = {}
-        member_event_ids = [
-            e_id
-            for key, e_id in current_state_ids.items()
-            if key[0] == EventTypes.Member
-        ]
-
-        if context is not None:
-            # If we have a context with a delta from a previous state group,
-            # check if we also have the result from the previous group in cache.
-            # If we do then we can reuse that result and simply update it with
-            # any membership changes in `delta_ids`
-            if context.prev_group and context.delta_ids:
-                prev_res = self._get_joined_users_from_context.cache.get_immediate(
-                    (room_id, context.prev_group), None
-                )
-                if prev_res and isinstance(prev_res, dict):
-                    users_in_room = dict(prev_res)
-                    member_event_ids = [
-                        e_id
-                        for key, e_id in context.delta_ids.items()
-                        if key[0] == EventTypes.Member
-                    ]
-                    for etype, state_key in context.delta_ids:
-                        if etype == EventTypes.Member:
-                            users_in_room.pop(state_key, None)
-
-        # We check if we have any of the member event ids in the event cache
-        # before we ask the DB
-
-        # We don't update the event cache hit ratio as it completely throws off
-        # the hit ratio counts. After all, we don't populate the cache if we
-        # miss it here
-        event_map = self._get_events_from_cache(member_event_ids, update_metrics=False)
-
-        missing_member_event_ids = []
-        for event_id in member_event_ids:
-            ev_entry = event_map.get(event_id)
-            if ev_entry and not ev_entry.event.rejected_reason:
-                if ev_entry.event.membership == Membership.JOIN:
-                    users_in_room[ev_entry.event.state_key] = ProfileInfo(
-                        display_name=ev_entry.event.content.get("displayname", None),
-                        avatar_url=ev_entry.event.content.get("avatar_url", None),
-                    )
-            else:
-                missing_member_event_ids.append(event_id)
-
-        if missing_member_event_ids:
-            event_to_memberships = await self._get_joined_profiles_from_event_ids(
-                missing_member_event_ids
+            # We don't update the event cache hit ratio as it completely throws off
+            # the hit ratio counts. After all, we don't populate the cache if we
+            # miss it here
+            event_map = self._get_events_from_local_cache(
+                member_event_ids, update_metrics=False
             )
-            users_in_room.update(row for row in event_to_memberships.values() if row)
-
-        if event is not None and event.type == EventTypes.Member:
-            if event.membership == Membership.JOIN:
-                if event.event_id in member_event_ids:
-                    users_in_room[event.state_key] = ProfileInfo(
-                        display_name=event.content.get("displayname", None),
-                        avatar_url=event.content.get("avatar_url", None),
+
+            missing_member_event_ids = []
+            for event_id in member_event_ids:
+                ev_entry = event_map.get(event_id)
+                if ev_entry and not ev_entry.event.rejected_reason:
+                    if ev_entry.event.membership == Membership.JOIN:
+                        users_in_room.add(ev_entry.event.state_key)
+                else:
+                    missing_member_event_ids.append(event_id)
+
+            if missing_member_event_ids:
+                event_to_memberships = (
+                    await self._get_user_ids_from_membership_event_ids(
+                        missing_member_event_ids
                     )
+                )
+                users_in_room.update(
+                    user_id for user_id in event_to_memberships.values() if user_id
+                )
 
-        return users_in_room
+            return users_in_room
 
-    @cached(max_entries=10000)
-    def _get_joined_profile_from_event_id(
+    @cached(
+        max_entries=10000,
+        # This name matches the old function that has been replaced - the cache name
+        # is kept here to maintain backwards compatibility.
+        name="_get_joined_profile_from_event_id",
+    )
+    def _get_user_id_from_membership_event_id(
         self, event_id: str
     ) -> Optional[Tuple[str, ProfileInfo]]:
         raise NotImplementedError()
 
     @cachedList(
-        cached_method_name="_get_joined_profile_from_event_id",
+        cached_method_name="_get_user_id_from_membership_event_id",
         list_name="event_ids",
     )
-    async def _get_joined_profiles_from_event_ids(
+    async def _get_user_ids_from_membership_event_ids(
         self, event_ids: Iterable[str]
-    ) -> Dict[str, Optional[Tuple[str, ProfileInfo]]]:
+    ) -> Dict[str, Optional[str]]:
         """For given set of member event_ids check if they point to a join
-        event and if so return the associated user and profile info.
+        event.
 
         Args:
             event_ids: The member event IDs to lookup
 
         Returns:
-            Map from event ID to `user_id` and ProfileInfo (or None if not join event).
+            Map from event ID to `user_id`, or None if event is not a join.
         """
 
         rows = await self.db_pool.simple_select_many_batch(
             table="room_memberships",
             column="event_id",
             iterable=event_ids,
-            retcols=("user_id", "display_name", "avatar_url", "event_id"),
+            retcols=("user_id", "event_id"),
             keyvalues={"membership": Membership.JOIN},
-            batch_size=500,
-            desc="_get_joined_profiles_from_event_ids",
+            batch_size=1000,
+            desc="_get_user_ids_from_membership_event_ids",
         )
 
-        return {
-            row["event_id"]: (
-                row["user_id"],
-                ProfileInfo(
-                    avatar_url=row["avatar_url"], display_name=row["display_name"]
-                ),
-            )
-            for row in rows
-        }
+        return {row["event_id"]: row["user_id"] for row in rows}
 
     @cached(max_entries=10000)
     async def is_host_joined(self, room_id: str, host: str) -> bool:
@@ -894,44 +1012,77 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return True
 
     @cached(iterable=True, max_entries=10000)
-    async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
-        """Get current hosts in room based on current state."""
+    async def get_current_hosts_in_room(self, room_id: str) -> List[str]:
+        """
+        Get current hosts in room based on current state.
+
+        The heuristic of sorting by servers who have been in the room the
+        longest is good because they're most likely to have anything we ask
+        about.
+
+        Returns:
+            Returns a list of servers sorted by longest in the room first. (aka.
+            sorted by join with the lowest depth first).
+        """
 
         # First we check if we already have `get_users_in_room` in the cache, as
         # we can just calculate result from that
         users = self.get_users_in_room.cache.get_immediate(
             (room_id,), None, update_metrics=False
         )
-        if users is not None:
-            return {get_domain_from_id(u) for u in users}
-
-        if isinstance(self.database_engine, Sqlite3Engine):
+        if users is None and isinstance(self.database_engine, Sqlite3Engine):
             # If we're using SQLite then let's just always use
             # `get_users_in_room` rather than funky SQL.
             users = await self.get_users_in_room(room_id)
-            return {get_domain_from_id(u) for u in users}
+
+        if users is not None:
+            # Because `users` is sorted from lowest -> highest depth, the list
+            # of domains will also be sorted that way.
+            domains: List[str] = []
+            # We use a `Set` just for fast lookups
+            domain_set: Set[str] = set()
+            for u in users:
+                domain = get_domain_from_id(u)
+                if domain not in domain_set:
+                    domain_set.add(domain)
+                    domains.append(domain)
+            return domains
 
         # For PostgreSQL we can use a regex to pull out the domains from the
         # joined users in `current_state_events` via regex.
 
-        def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]:
+        def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> List[str]:
+            # Returns a list of servers currently joined in the room sorted by
+            # longest in the room first (aka. with the lowest depth). The
+            # heuristic of sorting by servers who have been in the room the
+            # longest is good because they're most likely to have anything we
+            # ask about.
             sql = """
-                SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$')
-                FROM current_state_events
+                SELECT
+                    /* Match the domain part of the MXID */
+                    substring(c.state_key FROM '@[^:]*:(.*)$') as server_domain
+                FROM current_state_events c
+                /* Get the depth of the event from the events table */
+                INNER JOIN events AS e USING (event_id)
                 WHERE
-                    type = 'm.room.member'
-                    AND membership = 'join'
-                    AND room_id = ?
+                    /* Find any join state events in the room */
+                    c.type = 'm.room.member'
+                    AND c.membership = 'join'
+                    AND c.room_id = ?
+                /* Group all state events from the same domain into their own buckets (groups) */
+                GROUP BY server_domain
+                /* Sorted by lowest depth first */
+                ORDER BY min(e.depth) ASC;
             """
             txn.execute(sql, (room_id,))
-            return {d for d, in txn}
+            return [d for d, in txn]
 
         return await self.db_pool.runInteraction(
             "get_current_hosts_in_room", get_current_hosts_in_room_txn
         )
 
     async def get_joined_hosts(
-        self, room_id: str, state_entry: "_StateCacheEntry"
+        self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
     ) -> FrozenSet[str]:
         state_group: Union[object, int] = state_entry.state_group
         if not state_group:
@@ -944,7 +1095,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         assert state_group is not None
         with Measure(self._clock, "get_joined_hosts"):
             return await self._get_joined_hosts(
-                room_id, state_group, state_entry=state_entry
+                room_id, state_group, state, state_entry=state_entry
             )
 
     @cached(num_args=2, max_entries=10000, iterable=True)
@@ -952,6 +1103,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         self,
         room_id: str,
         state_group: Union[object, int],
+        state: StateMap[str],
         state_entry: "_StateCacheEntry",
     ) -> FrozenSet[str]:
         # We don't use `state_group`, it's there so that we can cache based on
@@ -1006,12 +1158,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             else:
                 # The cache doesn't match the state group or prev state group,
                 # so we calculate the result from first principles.
-                joined_users = await self.get_joined_users_from_state(
-                    room_id, state_entry
+                joined_user_ids = await self.get_joined_user_ids_from_state(
+                    room_id, state
                 )
 
                 cache.hosts_to_joined_users = {}
-                for user_id in joined_users:
+                for user_id in joined_user_ids:
                     host = intern_string(get_domain_from_id(user_id))
                     cache.hosts_to_joined_users.setdefault(host, set()).add(user_id)
 
@@ -1090,6 +1242,30 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
         )
 
+    async def is_locally_forgotten_room(self, room_id: str) -> bool:
+        """Returns whether all local users have forgotten this room_id.
+
+        Args:
+            room_id: The room ID to query.
+
+        Returns:
+            Whether the room is forgotten.
+        """
+
+        sql = """
+            SELECT count(*) > 0 FROM local_current_membership
+            INNER JOIN room_memberships USING (room_id, event_id)
+            WHERE
+                room_id = ?
+                AND forgotten = 0;
+        """
+
+        rows = await self.db_pool.execute("is_forgotten_room", None, sql, room_id)
+
+        # `count(*)` returns always an integer
+        # If any rows still exist it means someone has not forgotten this room yet
+        return not rows[0][0]
+
     async def get_rooms_user_has_been_in(self, user_id: str) -> Set[str]:
         """Get all rooms that the user has ever been in.
 
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 78e0773b2a..f6e24b68d2 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -113,7 +113,6 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
     EVENT_SEARCH_UPDATE_NAME = "event_search"
     EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
-    EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
     EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
     EVENT_SEARCH_DELETE_NON_STRINGS = "event_search_sqlite_delete_non_strings"
 
@@ -132,15 +131,6 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
             self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_reindex_search_order
         )
 
-        # we used to have a background update to turn the GIN index into a
-        # GIST one; we no longer do that (obviously) because we actually want
-        # a GIN index. However, it's possible that some people might still have
-        # the background update queued, so we register a handler to clear the
-        # background update.
-        self.db_pool.updates.register_noop_background_update(
-            self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME
-        )
-
         self.db_pool.updates.register_background_update_handler(
             self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search
         )
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 5e6efbd0fc..0b10af0e58 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -419,15 +419,22 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         # anything that was rejected should have the same state as its
         # predecessor.
         if context.rejected:
-            assert context.state_group == context.state_group_before_event
+            state_group = context.state_group_before_event
+        else:
+            state_group = context.state_group
 
         self.db_pool.simple_update_txn(
             txn,
             table="event_to_state_groups",
             keyvalues={"event_id": event.event_id},
-            updatevalues={"state_group": context.state_group},
+            updatevalues={"state_group": state_group},
         )
 
+        # the event may now be rejected where it was not before, or vice versa,
+        # in which case we need to update the rejected flags.
+        if bool(context.rejected) != (event.rejected_reason is not None):
+            self.mark_event_rejected_txn(txn, event.event_id, context.rejected)
+
         self.db_pool.simple_delete_one_txn(
             txn,
             table="partial_state_events",
@@ -435,11 +442,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
 
         # TODO(faster_joins): need to do something about workers here
+        #   https://github.com/matrix-org/synapse/issues/12994
         txn.call_after(self.is_partial_state_event.invalidate, (event.event_id,))
         txn.call_after(
             self._get_state_group_for_event.prefill,
             (event.event_id,),
-            context.state_group,
+            state_group,
         )
 
 
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index b95dbef678..b4c652acf3 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -16,7 +16,7 @@
 import logging
 from enum import Enum
 from itertools import chain
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
 
 from typing_extensions import Counter
 
@@ -120,11 +120,6 @@ class StatsStore(StateDeltasStore):
         self.db_pool.updates.register_background_update_handler(
             "populate_stats_process_users", self._populate_stats_process_users
         )
-        # we no longer need to perform clean-up, but we will give ourselves
-        # the potential to reintroduce it in the future – so documentation
-        # will still encourage the use of this no-op handler.
-        self.db_pool.updates.register_noop_background_update("populate_stats_cleanup")
-        self.db_pool.updates.register_noop_background_update("populate_stats_prepare")
 
     async def _populate_stats_process_users(
         self, progress: JsonDict, batch_size: int
@@ -243,6 +238,7 @@ class StatsStore(StateDeltasStore):
         * avatar
         * canonical_alias
         * guest_access
+        * room_type
 
         A is_federatable key can also be included with a boolean value.
 
@@ -268,6 +264,7 @@ class StatsStore(StateDeltasStore):
             "avatar",
             "canonical_alias",
             "guest_access",
+            "room_type",
         ):
             field = fields.get(col, sentinel)
             if field is not sentinel and (not isinstance(field, str) or "\0" in field):
@@ -300,6 +297,7 @@ class StatsStore(StateDeltasStore):
             keyvalues={id_col: id},
             retcol="completed_delta_stream_id",
             allow_none=True,
+            desc="get_earliest_token_for_stats",
         )
 
     async def bulk_update_stats_delta(
@@ -576,7 +574,7 @@ class StatsStore(StateDeltasStore):
 
         state_event_map = await self.get_events(event_ids, get_prev_content=False)  # type: ignore[attr-defined]
 
-        room_state = {
+        room_state: Dict[str, Union[None, bool, str]] = {
             "join_rules": None,
             "history_visibility": None,
             "encryption": None,
@@ -585,6 +583,7 @@ class StatsStore(StateDeltasStore):
             "avatar": None,
             "canonical_alias": None,
             "is_federatable": True,
+            "room_type": None,
         }
 
         for event in state_event_map.values():
@@ -608,6 +607,9 @@ class StatsStore(StateDeltasStore):
                 room_state["is_federatable"] = (
                     event.content.get(EventContentFields.FEDERATE, True) is True
                 )
+                room_type = event.content.get(EventContentFields.ROOM_TYPE)
+                if isinstance(room_type, str):
+                    room_state["room_type"] = room_type
 
         await self.update_room_state(room_id, room_state)
 
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 8e88784d3c..a347430aa7 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -46,16 +46,19 @@ from typing import (
     Set,
     Tuple,
     cast,
+    overload,
 )
 
 import attr
 from frozendict import frozendict
+from typing_extensions import Literal
 
 from twisted.internet import defer
 
 from synapse.api.filtering import Filter
 from synapse.events import EventBase
 from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.logging.opentracing import trace
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
     DatabasePool,
@@ -795,6 +798,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
         return RoomStreamToken(topo, stream_ordering)
 
+    @overload
+    def get_stream_id_for_event_txn(
+        self,
+        txn: LoggingTransaction,
+        event_id: str,
+        allow_none: Literal[False] = False,
+    ) -> int:
+        ...
+
+    @overload
+    def get_stream_id_for_event_txn(
+        self,
+        txn: LoggingTransaction,
+        event_id: str,
+        allow_none: bool = False,
+    ) -> Optional[int]:
+        ...
+
     def get_stream_id_for_event_txn(
         self,
         txn: LoggingTransaction,
@@ -1002,8 +1023,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         }
 
     async def get_all_new_events_stream(
-        self, from_id: int, current_id: int, limit: int
-    ) -> Tuple[int, List[EventBase]]:
+        self, from_id: int, current_id: int, limit: int, get_prev_content: bool = False
+    ) -> Tuple[int, List[EventBase], Dict[str, Optional[int]]]:
         """Get all new events
 
         Returns all events with from_id < stream_ordering <= current_id.
@@ -1012,19 +1033,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             from_id:  the stream_ordering of the last event we processed
             current_id:  the stream_ordering of the most recently processed event
             limit: the maximum number of events to return
+            get_prev_content: whether to fetch previous event content
 
         Returns:
-            A tuple of (next_id, events), where `next_id` is the next value to
-            pass as `from_id` (it will either be the stream_ordering of the
-            last returned event, or, if fewer than `limit` events were found,
-            the `current_id`).
+            A tuple of (next_id, events, event_to_received_ts), where `next_id`
+            is the next value to pass as `from_id` (it will either be the
+            stream_ordering of the last returned event, or, if fewer than `limit`
+            events were found, the `current_id`). The `event_to_received_ts` is
+            a dictionary mapping event ID to the event `received_ts`.
         """
 
         def get_all_new_events_stream_txn(
             txn: LoggingTransaction,
-        ) -> Tuple[int, List[str]]:
+        ) -> Tuple[int, Dict[str, Optional[int]]]:
             sql = (
-                "SELECT e.stream_ordering, e.event_id"
+                "SELECT e.stream_ordering, e.event_id, e.received_ts"
                 " FROM events AS e"
                 " WHERE"
                 " ? < e.stream_ordering AND e.stream_ordering <= ?"
@@ -1039,15 +1062,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             if len(rows) == limit:
                 upper_bound = rows[-1][0]
 
-            return upper_bound, [row[1] for row in rows]
+            event_to_received_ts: Dict[str, Optional[int]] = {
+                row[1]: row[2] for row in rows
+            }
+            return upper_bound, event_to_received_ts
 
-        upper_bound, event_ids = await self.db_pool.runInteraction(
+        upper_bound, event_to_received_ts = await self.db_pool.runInteraction(
             "get_all_new_events_stream", get_all_new_events_stream_txn
         )
 
-        events = await self.get_events_as_list(event_ids)
+        events = await self.get_events_as_list(
+            event_to_received_ts.keys(),
+            get_prev_content=get_prev_content,
+        )
 
-        return upper_bound, events
+        return upper_bound, events, event_to_received_ts
 
     async def get_federation_out_pos(self, typ: str) -> int:
         if self._need_to_reset_federation_stream_positions:
@@ -1318,6 +1347,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return rows, next_token
 
+    @trace
     async def paginate_room_events(
         self,
         room_id: str,
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index fa9eadaca7..a7fcc564a9 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -24,6 +24,7 @@ from synapse.storage.database import (
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.state import StateFilter
 from synapse.types import MutableStateMap, StateMap
+from synapse.util.caches import intern_string
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -136,7 +137,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
                 txn.execute(sql % (where_clause,), args)
                 for row in txn:
                     typ, state_key, event_id = row
-                    key = (typ, state_key)
+                    key = (intern_string(typ), intern_string(state_key))
                     results[group][key] = event_id
         else:
             max_entries_returned = state_filter.max_entries_returned()
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 609a2b88bf..bb64543c1f 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -202,7 +202,14 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 requests state from the cache, if False we need to query the DB for the
                 missing state.
         """
-        cache_entry = cache.get(group)
+        # If we are asked explicitly for a subset of keys, we only ask for those
+        # from the cache. This ensures that the `DictionaryCache` can make
+        # better decisions about what to cache and what to expire.
+        dict_keys = None
+        if not state_filter.has_wildcards():
+            dict_keys = state_filter.concrete_types()
+
+        cache_entry = cache.get(group, dict_keys=dict_keys)
         state_dict_ids = cache_entry.value
 
         if cache_entry.full or state_filter.is_full():
@@ -400,14 +407,17 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         room_id: str,
         prev_group: Optional[int],
         delta_ids: Optional[StateMap[str]],
-        current_state_ids: StateMap[str],
+        current_state_ids: Optional[StateMap[str]],
     ) -> int:
         """Store a new set of state, returning a newly assigned state group.
 
+        At least one of `current_state_ids` and `prev_group` must be provided. Whenever
+        `prev_group` is not None, `delta_ids` must also not be None.
+
         Args:
             event_id: The event ID for which the state was calculated
             room_id
-            prev_group: A previous state group for the room, optional.
+            prev_group: A previous state group for the room.
             delta_ids: The delta between state at `prev_group` and
                 `current_state_ids`, if `prev_group` was given. Same format as
                 `current_state_ids`.
@@ -418,10 +428,41 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             The state group ID
         """
 
-        def _store_state_group_txn(txn: LoggingTransaction) -> int:
-            if current_state_ids is None:
-                # AFAIK, this can never happen
-                raise Exception("current_state_ids cannot be None")
+        if prev_group is None and current_state_ids is None:
+            raise Exception("current_state_ids and prev_group can't both be None")
+
+        if prev_group is not None and delta_ids is None:
+            raise Exception("delta_ids is None when prev_group is not None")
+
+        def insert_delta_group_txn(
+            txn: LoggingTransaction, prev_group: int, delta_ids: StateMap[str]
+        ) -> Optional[int]:
+            """Try and persist the new group as a delta.
+
+            Requires that we have the state as a delta from a previous state group.
+
+            Returns:
+                The state group if successfully created, or None if the state
+                needs to be persisted as a full state.
+            """
+            is_in_db = self.db_pool.simple_select_one_onecol_txn(
+                txn,
+                table="state_groups",
+                keyvalues={"id": prev_group},
+                retcol="id",
+                allow_none=True,
+            )
+            if not is_in_db:
+                raise Exception(
+                    "Trying to persist state with unpersisted prev_group: %r"
+                    % (prev_group,)
+                )
+
+            # if the chain of state group deltas is going too long, we fall back to
+            # persisting a complete state group.
+            potential_hops = self._count_state_group_hops_txn(txn, prev_group)
+            if potential_hops >= MAX_STATE_DELTA_HOPS:
+                return None
 
             state_group = self._state_group_seq_gen.get_next_id_txn(txn)
 
@@ -431,51 +472,45 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 values={"id": state_group, "room_id": room_id, "event_id": event_id},
             )
 
-            # We persist as a delta if we can, while also ensuring the chain
-            # of deltas isn't tooo long, as otherwise read performance degrades.
-            if prev_group:
-                is_in_db = self.db_pool.simple_select_one_onecol_txn(
-                    txn,
-                    table="state_groups",
-                    keyvalues={"id": prev_group},
-                    retcol="id",
-                    allow_none=True,
-                )
-                if not is_in_db:
-                    raise Exception(
-                        "Trying to persist state with unpersisted prev_group: %r"
-                        % (prev_group,)
-                    )
-
-                potential_hops = self._count_state_group_hops_txn(txn, prev_group)
-            if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
-                assert delta_ids is not None
-
-                self.db_pool.simple_insert_txn(
-                    txn,
-                    table="state_group_edges",
-                    values={"state_group": state_group, "prev_state_group": prev_group},
-                )
+            self.db_pool.simple_insert_txn(
+                txn,
+                table="state_group_edges",
+                values={"state_group": state_group, "prev_state_group": prev_group},
+            )
 
-                self.db_pool.simple_insert_many_txn(
-                    txn,
-                    table="state_groups_state",
-                    keys=("state_group", "room_id", "type", "state_key", "event_id"),
-                    values=[
-                        (state_group, room_id, key[0], key[1], state_id)
-                        for key, state_id in delta_ids.items()
-                    ],
-                )
-            else:
-                self.db_pool.simple_insert_many_txn(
-                    txn,
-                    table="state_groups_state",
-                    keys=("state_group", "room_id", "type", "state_key", "event_id"),
-                    values=[
-                        (state_group, room_id, key[0], key[1], state_id)
-                        for key, state_id in current_state_ids.items()
-                    ],
-                )
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="state_groups_state",
+                keys=("state_group", "room_id", "type", "state_key", "event_id"),
+                values=[
+                    (state_group, room_id, key[0], key[1], state_id)
+                    for key, state_id in delta_ids.items()
+                ],
+            )
+
+            return state_group
+
+        def insert_full_state_txn(
+            txn: LoggingTransaction, current_state_ids: StateMap[str]
+        ) -> int:
+            """Persist the full state, returning the new state group."""
+            state_group = self._state_group_seq_gen.get_next_id_txn(txn)
+
+            self.db_pool.simple_insert_txn(
+                txn,
+                table="state_groups",
+                values={"id": state_group, "room_id": room_id, "event_id": event_id},
+            )
+
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="state_groups_state",
+                keys=("state_group", "room_id", "type", "state_key", "event_id"),
+                values=[
+                    (state_group, room_id, key[0], key[1], state_id)
+                    for key, state_id in current_state_ids.items()
+                ],
+            )
 
             # Prefill the state group caches with this group.
             # It's fine to use the sequence like this as the state group map
@@ -491,7 +526,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 self._state_group_members_cache.update,
                 self._state_group_members_cache.sequence,
                 key=state_group,
-                value=dict(current_member_state_ids),
+                value=current_member_state_ids,
             )
 
             current_non_member_state_ids = {
@@ -503,13 +538,35 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 self._state_group_cache.update,
                 self._state_group_cache.sequence,
                 key=state_group,
-                value=dict(current_non_member_state_ids),
+                value=current_non_member_state_ids,
             )
 
             return state_group
 
+        if prev_group is not None:
+            state_group = await self.db_pool.runInteraction(
+                "store_state_group.insert_delta_group",
+                insert_delta_group_txn,
+                prev_group,
+                delta_ids,
+            )
+            if state_group is not None:
+                return state_group
+
+        # We're going to persist the state as a complete group rather than
+        # a delta, so first we need to ensure we have loaded the state map
+        # from the database.
+        if current_state_ids is None:
+            assert prev_group is not None
+            assert delta_ids is not None
+            groups = await self._get_state_for_groups([prev_group])
+            current_state_ids = dict(groups[prev_group])
+            current_state_ids.update(delta_ids)
+
         return await self.db_pool.runInteraction(
-            "store_state_group", _store_state_group_txn
+            "store_state_group.insert_full_state",
+            insert_full_state_txn,
+            current_state_ids,
         )
 
     async def purge_unreferenced_state_groups(
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index f51b3d228e..a182e8a098 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -11,11 +11,35 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Mapping
+from typing import Any, Mapping, NoReturn
 
 from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
-from .postgres import PostgresEngine
-from .sqlite import Sqlite3Engine
+
+# The classes `PostgresEngine` and `Sqlite3Engine` must always be importable, because
+# we use `isinstance(engine, PostgresEngine)` to write different queries for postgres
+# and sqlite. But the database driver modules are both optional: they may not be
+# installed. To account for this, create dummy classes on import failure so we can
+# still run `isinstance()` checks.
+try:
+    from .postgres import PostgresEngine
+except ImportError:
+
+    class PostgresEngine(BaseDatabaseEngine):  # type: ignore[no-redef]
+        def __new__(cls, *args: object, **kwargs: object) -> NoReturn:  # type: ignore[misc]
+            raise RuntimeError(
+                f"Cannot create {cls.__name__} -- psycopg2 module is not installed"
+            )
+
+
+try:
+    from .sqlite import Sqlite3Engine
+except ImportError:
+
+    class Sqlite3Engine(BaseDatabaseEngine):  # type: ignore[no-redef]
+        def __new__(cls, *args: object, **kwargs: object) -> NoReturn:  # type: ignore[misc]
+            raise RuntimeError(
+                f"Cannot create {cls.__name__} -- sqlite3 module is not installed"
+            )
 
 
 def create_engine(database_config: Mapping[str, Any]) -> BaseDatabaseEngine:
@@ -30,4 +54,10 @@ def create_engine(database_config: Mapping[str, Any]) -> BaseDatabaseEngine:
     raise RuntimeError("Unsupported database engine '%s'" % (name,))
 
 
-__all__ = ["create_engine", "BaseDatabaseEngine", "IncorrectDatabaseSetup"]
+__all__ = [
+    "create_engine",
+    "BaseDatabaseEngine",
+    "PostgresEngine",
+    "Sqlite3Engine",
+    "IncorrectDatabaseSetup",
+]
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 391f8ed24a..517f9d5f98 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -15,6 +15,8 @@
 import logging
 from typing import TYPE_CHECKING, Any, Mapping, NoReturn, Optional, Tuple, cast
 
+import psycopg2.extensions
+
 from synapse.storage.engines._base import (
     BaseDatabaseEngine,
     IncorrectDatabaseSetup,
@@ -23,18 +25,14 @@ from synapse.storage.engines._base import (
 from synapse.storage.types import Cursor
 
 if TYPE_CHECKING:
-    import psycopg2  # noqa: F401
-
     from synapse.storage.database import LoggingDatabaseConnection
 
 
 logger = logging.getLogger(__name__)
 
 
-class PostgresEngine(BaseDatabaseEngine["psycopg2.connection"]):
+class PostgresEngine(BaseDatabaseEngine[psycopg2.extensions.connection]):
     def __init__(self, database_config: Mapping[str, Any]):
-        import psycopg2.extensions
-
         super().__init__(psycopg2, database_config)
         psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
 
@@ -69,7 +67,9 @@ class PostgresEngine(BaseDatabaseEngine["psycopg2.connection"]):
         return collation, ctype
 
     def check_database(
-        self, db_conn: "psycopg2.connection", allow_outdated_version: bool = False
+        self,
+        db_conn: psycopg2.extensions.connection,
+        allow_outdated_version: bool = False,
     ) -> None:
         # Get the version of PostgreSQL that we're using. As per the psycopg2
         # docs: The number is formed by converting the major, minor, and
@@ -176,8 +176,6 @@ class PostgresEngine(BaseDatabaseEngine["psycopg2.connection"]):
         return True
 
     def is_deadlock(self, error: Exception) -> bool:
-        import psycopg2.extensions
-
         if isinstance(error, psycopg2.DatabaseError):
             # https://www.postgresql.org/docs/current/static/errcodes-appendix.html
             # "40001" serialization_failure
@@ -185,7 +183,7 @@ class PostgresEngine(BaseDatabaseEngine["psycopg2.connection"]):
             return error.pgcode in ["40001", "40P01"]
         return False
 
-    def is_connection_closed(self, conn: "psycopg2.connection") -> bool:
+    def is_connection_closed(self, conn: psycopg2.extensions.connection) -> bool:
         return bool(conn.closed)
 
     def lock_table(self, txn: Cursor, table: str) -> None:
@@ -205,18 +203,16 @@ class PostgresEngine(BaseDatabaseEngine["psycopg2.connection"]):
         else:
             return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100)
 
-    def in_transaction(self, conn: "psycopg2.connection") -> bool:
-        import psycopg2.extensions
-
+    def in_transaction(self, conn: psycopg2.extensions.connection) -> bool:
         return conn.status != psycopg2.extensions.STATUS_READY
 
     def attempt_to_set_autocommit(
-        self, conn: "psycopg2.connection", autocommit: bool
+        self, conn: psycopg2.extensions.connection, autocommit: bool
     ) -> None:
         return conn.set_session(autocommit=autocommit)
 
     def attempt_to_set_isolation_level(
-        self, conn: "psycopg2.connection", isolation_level: Optional[int]
+        self, conn: psycopg2.extensions.connection, isolation_level: Optional[int]
     ) -> None:
         if isolation_level is None:
             isolation_level = self.default_isolation_level
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index c33df42084..09a2b58f4c 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -23,8 +23,7 @@ from typing_extensions import Counter as CounterType
 
 from synapse.config.homeserver import HomeServerConfig
 from synapse.storage.database import LoggingDatabaseConnection
-from synapse.storage.engines import BaseDatabaseEngine
-from synapse.storage.engines.postgres import PostgresEngine
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
 from synapse.storage.schema import SCHEMA_COMPAT_VERSION, SCHEMA_VERSION
 from synapse.storage.types import Cursor
 
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 5843fae605..256f745dc0 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-SCHEMA_VERSION = 71  # remember to update the list below when updating
+SCHEMA_VERSION = 72  # remember to update the list below when updating
 """Represents the expectations made by the codebase about the database schema
 
 This should be incremented whenever the codebase changes its requirements on the
@@ -71,14 +71,18 @@ Changes in SCHEMA_VERSION = 70:
 Changes in SCHEMA_VERSION = 71:
     - event_edges.room_id is no longer read from.
     - Tables related to groups are no longer accessed.
+
+Changes in SCHEMA_VERSION = 72:
+    - event_edges.(room_id, is_state) are no longer written to.
+    - Tables related to groups are dropped.
+    - Unused column application_services_state.last_txn is dropped
 """
 
 
 SCHEMA_COMPAT_VERSION = (
-    # We now assume that `device_lists_changes_in_room` has been filled out for
-    # recent device_list_updates.
-    # ... and that `application_services_state.last_txn` is not used.
-    69
+    # The groups tables are no longer accessible, so synapses with SCHEMA_VERSION < 72
+    # could break.
+    72
 )
 """Limit on how far the synapse codebase can be rolled back without breaking db compat
 
diff --git a/synapse/storage/schema/main/delta/40/event_push_summary.sql b/synapse/storage/schema/main/delta/40/event_push_summary.sql
index 3918f0b794..499bf60178 100644
--- a/synapse/storage/schema/main/delta/40/event_push_summary.sql
+++ b/synapse/storage/schema/main/delta/40/event_push_summary.sql
@@ -13,9 +13,10 @@
  * limitations under the License.
  */
 
--- Aggregate of old notification counts that have been deleted out of the
--- main event_push_actions table. This count does not include those that were
--- highlights, as they remain in the event_push_actions table.
+-- Aggregate of notification counts up to `stream_ordering`, including those
+-- that may have been deleted out of the main event_push_actions table. This
+-- count does not include those that were highlights, as they remain in the
+-- event_push_actions table.
 CREATE TABLE event_push_summary (
     user_id TEXT NOT NULL,
     room_id TEXT NOT NULL,
diff --git a/synapse/storage/schema/main/delta/71/01rebuild_event_edges.sql.postgres b/synapse/storage/schema/main/delta/71/01rebuild_event_edges.sql.postgres
new file mode 100644
index 0000000000..f32f445858
--- /dev/null
+++ b/synapse/storage/schema/main/delta/71/01rebuild_event_edges.sql.postgres
@@ -0,0 +1,43 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We're going to stop populating event_edges.room_id and event_edges.is_state,
+-- which means we now need to give them defaults.
+
+-- We also drop the exising unique constraint which spans all four columns. Franky
+-- it's not doing much, and there are other indexes on event_id and prev_event_id.
+-- Later on we introduce a proper unique constraint on (event_id, prev_event_id).
+--
+-- We also add a foreign key constraint (which will be enforced for new rows), but
+-- don't yet validate it for existing rows (since that's slow, and we haven't yet
+-- checked that all the rows are valid)
+
+ALTER TABLE event_edges
+   ALTER room_id DROP NOT NULL,
+   ALTER is_state SET DEFAULT FALSE,
+   DROP CONSTRAINT IF EXISTS event_edges_event_id_prev_event_id_room_id_is_state_key,
+   ADD CONSTRAINT event_edges_event_id_fkey FOREIGN KEY (event_id) REFERENCES events(event_id) NOT VALID;
+
+-- In the background, we drop any rows with is_state=True. These may have been
+-- added a long time ago, but they are no longer used.
+--
+-- We also drop rows that do not correspond to entries in `events`, and finally
+-- validate the foreign key.
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (7101, 'event_edges_drop_invalid_rows', '{}');
+
+-- We'll then create a new unique index on (event_id, prev_event_id).
+INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
+  (7101, 'event_edges_replace_index', '{}', 'event_edges_drop_invalid_rows');
diff --git a/synapse/storage/schema/main/delta/71/01rebuild_event_edges.sql.sqlite b/synapse/storage/schema/main/delta/71/01rebuild_event_edges.sql.sqlite
new file mode 100644
index 0000000000..0bb86edd2a
--- /dev/null
+++ b/synapse/storage/schema/main/delta/71/01rebuild_event_edges.sql.sqlite
@@ -0,0 +1,47 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We're going to stop populating event_edges.room_id and event_edges.is_state,
+-- which means we now need to give them defaults.
+--
+-- We also take the opportunity to:
+--   - drop any rows with is_state=True (these were populated a long time ago, but
+--     are no longer used.)
+--   - drop any rows which do not correspond to entries in `events`
+--   - tighten the unique index so that it applies just to (event_id, prev_event_id)
+--   - drop the "ev_edges_id" index, which is redundant to the above.
+--   - add a foreign key constraint from event_id to `events`
+
+CREATE TABLE new_event_edges (
+  event_id TEXT NOT NULL,
+  prev_event_id TEXT NOT NULL,
+  room_id TEXT NULL,
+  is_state BOOL NOT NULL DEFAULT 0,
+  FOREIGN KEY(event_id) REFERENCES events(event_id)
+);
+
+INSERT INTO new_event_edges
+    SELECT ee.event_id, ee.prev_event_id, ee.room_id, ee.is_state
+    FROM event_edges ee JOIN events ev USING (event_id)
+    WHERE NOT ee.is_state;
+
+DROP TABLE event_edges;
+
+ALTER TABLE new_event_edges RENAME TO event_edges;
+
+CREATE UNIQUE INDEX event_edges_event_id_prev_event_id_idx
+  ON event_edges (event_id, prev_event_id);
+
+CREATE INDEX ev_edges_prev_id ON event_edges (prev_event_id);
diff --git a/synapse/storage/schema/main/delta/71/01remove_noop_background_updates.sql b/synapse/storage/schema/main/delta/71/01remove_noop_background_updates.sql
new file mode 100644
index 0000000000..fa96ac50c2
--- /dev/null
+++ b/synapse/storage/schema/main/delta/71/01remove_noop_background_updates.sql
@@ -0,0 +1,61 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Clean-up background updates which should no longer be run. Previously these
+-- used the (now removed) register_noop_background_update method.
+
+-- Used to be a background update that deletes all device_inboxes for deleted
+-- devices.
+DELETE FROM background_updates WHERE update_name = 'remove_deleted_devices_from_device_inbox';
+-- Used to be a background update that deletes all device_inboxes for hidden
+-- devices.
+DELETE FROM background_updates WHERE update_name = 'remove_hidden_devices_from_device_inbox';
+
+-- A pair of background updates that were added during the 1.14 release cycle,
+-- but replaced with 58/06dlols_unique_idx.py
+DELETE FROM background_updates WHERE update_name = 'device_lists_outbound_last_success_unique_idx';
+DELETE FROM background_updates WHERE update_name = 'drop_device_lists_outbound_last_success_non_unique_idx';
+
+-- The event_thread_relation background update was replaced with the
+-- event_arbitrary_relations one, which handles any relation to avoid
+-- needed to potentially crawl the entire events table in the future.
+DELETE FROM background_updates WHERE update_name = 'event_thread_relation';
+
+-- A legacy groups background update.
+DELETE FROM background_updates WHERE update_name = 'local_group_updates_index';
+
+-- The original impl of _drop_media_index_without_method was broken (see
+-- https://github.com/matrix-org/synapse/issues/8649), so we replace the original
+-- impl with a no-op and run the fixed migration as
+-- media_repository_drop_index_wo_method_2.
+DELETE FROM background_updates WHERE update_name = 'media_repository_drop_index_wo_method';
+
+-- We no longer use refresh tokens, but it's possible that some people
+-- might have a background update queued to build this index. Just
+-- clear the background update.
+DELETE FROM background_updates WHERE update_name = 'refresh_tokens_device_index';
+
+DELETE FROM background_updates WHERE update_name = 'user_threepids_grandfather';
+
+-- We used to have a background update to turn the GIN index into a
+-- GIST one; we no longer do that (obviously) because we actually want
+-- a GIN index. However, it's possible that some people might still have
+-- the background update queued, so we register a handler to clear the
+-- background update.
+DELETE FROM background_updates WHERE update_name = 'event_search_postgres_gist';
+
+-- We no longer need to perform clean-up.
+DELETE FROM background_updates WHERE update_name = 'populate_stats_cleanup';
+DELETE FROM background_updates WHERE update_name = 'populate_stats_prepare';
diff --git a/synapse/storage/schema/main/delta/71/02event_push_summary_unique.sql b/synapse/storage/schema/main/delta/71/02event_push_summary_unique.sql
new file mode 100644
index 0000000000..9cdcea21ae
--- /dev/null
+++ b/synapse/storage/schema/main/delta/71/02event_push_summary_unique.sql
@@ -0,0 +1,18 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Add a unique index to `event_push_summary`
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (7002, 'event_push_summary_unique_index', '{}');
diff --git a/synapse/storage/schema/main/delta/72/01add_room_type_to_state_stats.sql b/synapse/storage/schema/main/delta/72/01add_room_type_to_state_stats.sql
new file mode 100644
index 0000000000..d5e0765471
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/01add_room_type_to_state_stats.sql
@@ -0,0 +1,19 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE room_stats_state ADD room_type TEXT;
+
+INSERT INTO background_updates (update_name, progress_json)
+    VALUES ('add_room_type_column', '{}');
diff --git a/synapse/storage/schema/main/delta/72/01event_push_summary_receipt.sql b/synapse/storage/schema/main/delta/72/01event_push_summary_receipt.sql
new file mode 100644
index 0000000000..e45db61529
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/01event_push_summary_receipt.sql
@@ -0,0 +1,35 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Add a column that records the position of the read receipt for the user at
+-- the time we summarised the push actions. This is used to check if the counts
+-- are up to date after a new read receipt has been sent.
+--
+-- Null means that we can skip that check, as the row was written by an older
+-- version of Synapse that updated `event_push_summary` synchronously when
+-- persisting a new read receipt
+ALTER TABLE event_push_summary ADD COLUMN last_receipt_stream_ordering BIGINT;
+
+
+-- Tracks which new receipts we've handled
+CREATE TABLE event_push_summary_last_receipt_stream_id (
+    Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE,  -- Makes sure this table only has one row.
+    stream_id BIGINT NOT NULL,
+    CHECK (Lock='X')
+);
+
+INSERT INTO event_push_summary_last_receipt_stream_id (stream_id)
+  SELECT COALESCE(MAX(stream_id), 0)
+  FROM receipts_linearized;
diff --git a/synapse/storage/schema/main/delta/72/02event_push_actions_index.sql b/synapse/storage/schema/main/delta/72/02event_push_actions_index.sql
new file mode 100644
index 0000000000..cd0c11d951
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/02event_push_actions_index.sql
@@ -0,0 +1,19 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Add an index to `event_push_actions` to make deleting old non-highlight push
+-- actions faster.
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (7202, 'event_push_actions_stream_highlight_index', '{}');
diff --git a/synapse/storage/schema/main/delta/72/03bg_populate_events_columns.py b/synapse/storage/schema/main/delta/72/03bg_populate_events_columns.py
new file mode 100644
index 0000000000..55a5d092cc
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/03bg_populate_events_columns.py
@@ -0,0 +1,47 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+from synapse.storage.types import Cursor
+
+
+def run_create(cur: Cursor, database_engine, *args, **kwargs):
+    """Add a bg update to populate the `state_key` and `rejection_reason` columns of `events`"""
+
+    # we know that any new events will have the columns populated (and that has been
+    # the case since schema_version 68, so there is no chance of rolling back now).
+    #
+    # So, we only need to make sure that existing rows are updated. We read the
+    # current min and max stream orderings, since that is guaranteed to include all
+    # the events that were stored before the new columns were added.
+    cur.execute("SELECT MIN(stream_ordering), MAX(stream_ordering) FROM events")
+    (min_stream_ordering, max_stream_ordering) = cur.fetchone()
+
+    if min_stream_ordering is None:
+        # no rows, nothing to do.
+        return
+
+    cur.execute(
+        "INSERT into background_updates (ordering, update_name, progress_json)"
+        " VALUES (7203, 'events_populate_state_key_rejections', ?)",
+        (
+            json.dumps(
+                {
+                    "min_stream_ordering_exclusive": min_stream_ordering - 1,
+                    "max_stream_ordering_inclusive": max_stream_ordering,
+                }
+            ),
+        ),
+    )
diff --git a/synapse/storage/schema/main/delta/72/03drop_event_reference_hashes.sql b/synapse/storage/schema/main/delta/72/03drop_event_reference_hashes.sql
new file mode 100644
index 0000000000..0da668aa3a
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/03drop_event_reference_hashes.sql
@@ -0,0 +1,17 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- event_reference_hashes is unused, so we can drop it
+DROP TABLE event_reference_hashes;
diff --git a/synapse/storage/schema/main/delta/72/03remove_groups.sql b/synapse/storage/schema/main/delta/72/03remove_groups.sql
new file mode 100644
index 0000000000..b7c5894de8
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/03remove_groups.sql
@@ -0,0 +1,31 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Remove the tables which powered the unspecced groups/communities feature.
+DROP TABLE IF EXISTS group_attestations_remote;
+DROP TABLE IF EXISTS group_attestations_renewals;
+DROP TABLE IF EXISTS group_invites;
+DROP TABLE IF EXISTS group_roles;
+DROP TABLE IF EXISTS group_room_categories;
+DROP TABLE IF EXISTS group_rooms;
+DROP TABLE IF EXISTS group_summary_roles;
+DROP TABLE IF EXISTS group_summary_room_categories;
+DROP TABLE IF EXISTS group_summary_rooms;
+DROP TABLE IF EXISTS group_summary_users;
+DROP TABLE IF EXISTS group_users;
+DROP TABLE IF EXISTS groups;
+DROP TABLE IF EXISTS local_group_membership;
+DROP TABLE IF EXISTS local_group_updates;
+DROP TABLE IF EXISTS remote_profile_cache;
diff --git a/synapse/storage/schema/main/delta/72/04drop_column_application_services_state_last_txn.sql.postgres b/synapse/storage/schema/main/delta/72/04drop_column_application_services_state_last_txn.sql.postgres
new file mode 100644
index 0000000000..13d47de9e6
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/04drop_column_application_services_state_last_txn.sql.postgres
@@ -0,0 +1,17 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Drop unused column application_services_state.last_txn
+ALTER table application_services_state DROP COLUMN last_txn;
\ No newline at end of file
diff --git a/synapse/storage/schema/main/delta/72/04drop_column_application_services_state_last_txn.sql.sqlite b/synapse/storage/schema/main/delta/72/04drop_column_application_services_state_last_txn.sql.sqlite
new file mode 100644
index 0000000000..3be1c88d72
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/04drop_column_application_services_state_last_txn.sql.sqlite
@@ -0,0 +1,40 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Drop unused column application_services_state.last_txn
+
+CREATE TABLE application_services_state2 (
+    as_id TEXT PRIMARY KEY NOT NULL,
+    state VARCHAR(5),
+    read_receipt_stream_id BIGINT,
+    presence_stream_id BIGINT,
+    to_device_stream_id BIGINT,
+    device_list_stream_id BIGINT
+);
+
+
+INSERT INTO application_services_state2 (
+    as_id,
+    state,
+    read_receipt_stream_id,
+    presence_stream_id,
+    to_device_stream_id,
+    device_list_stream_id
+)
+SELECT as_id, state, read_receipt_stream_id, presence_stream_id, to_device_stream_id,  device_list_stream_id
+FROM application_services_state;
+
+DROP TABLE application_services_state;
+ALTER TABLE application_services_state2 RENAME TO application_services_state;
diff --git a/synapse/storage/schema/main/delta/72/05remove_unstable_private_read_receipts.sql b/synapse/storage/schema/main/delta/72/05remove_unstable_private_read_receipts.sql
new file mode 100644
index 0000000000..36b41049cd
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/05remove_unstable_private_read_receipts.sql
@@ -0,0 +1,19 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Drop previously received private read receipts so they do not accidentally
+-- get leaked to other users.
+DELETE FROM receipts_linearized WHERE receipt_type = 'org.matrix.msc2285.read.private';
+DELETE FROM receipts_graph WHERE receipt_type = 'org.matrix.msc2285.read.private';
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 96aaffb53c..0004d955b4 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -539,14 +539,6 @@ class StateFilter:
             is_mine_id: a callable which confirms if a given state_key matches a mxid
                of a local user
         """
-
-        # TODO(faster_joins): it's not entirely clear that this is safe. In particular,
-        #  there may be circumstances in which we return a piece of state that, once we
-        #  resync the state, we discover is invalid. For example: if it turns out that
-        #  the sender of a piece of state wasn't actually in the room, then clearly that
-        #  state shouldn't have been returned.
-        #  We should at least add some tests around this to see what happens.
-
         # if we haven't requested membership events, then it depends on the value of
         # 'include_others'
         if EventTypes.Member not in self.types:
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 3c13859faa..2dfe4c0b66 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -460,8 +460,17 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
                 # Cast safety: this corresponds to the types returned by the query above.
                 rows.extend(cast(Iterable[Tuple[str, int]], cur))
 
-            # Sort so that we handle rows in order for each instance.
-            rows.sort()
+            # Sort by stream_id (ascending, lowest -> highest) so that we handle
+            # rows in order for each instance because we don't want to overwrite
+            # the current_position of an instance to a lower stream ID than
+            # we're actually at.
+            def sort_by_stream_id_key_func(row: Tuple[str, int]) -> int:
+                (instance, stream_id) = row
+                # If `stream_id` is ever `None`, we will see a `TypeError: '<'
+                # not supported between instances of 'NoneType' and 'X'` error.
+                return stream_id
+
+            rows.sort(key=sort_by_stream_id_key_func)
 
             with self._lock:
                 for (
diff --git a/synapse/storage/util/partial_state_events_tracker.py b/synapse/storage/util/partial_state_events_tracker.py
index 211437cfaa..b4bf49dace 100644
--- a/synapse/storage/util/partial_state_events_tracker.py
+++ b/synapse/storage/util/partial_state_events_tracker.py
@@ -20,6 +20,7 @@ from twisted.internet import defer
 from twisted.internet.defer import Deferred
 
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.logging.opentracing import trace_with_opname
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.room import RoomWorkerStore
 from synapse.util import unwrapFirstError
@@ -58,6 +59,7 @@ class PartialStateEventsTracker:
             for o in observers:
                 o.callback(None)
 
+    @trace_with_opname("PartialStateEventsTracker.await_full_state")
     async def await_full_state(self, event_ids: Collection[str]) -> None:
         """Wait for all the given events to have full state.
 
@@ -151,6 +153,7 @@ class PartialCurrentStateTracker:
             for o in observers:
                 o.callback(None)
 
+    @trace_with_opname("PartialCurrentStateTracker.await_full_state")
     async def await_full_state(self, room_id: str) -> None:
         # We add the deferred immediately so that the DB call to check for
         # partial state doesn't race when we unpartial the room.
@@ -166,6 +169,7 @@ class PartialCurrentStateTracker:
             logger.info(
                 "Awaiting un-partial-stating of room %s",
                 room_id,
+                stack_info=True,
             )
 
             await make_deferred_yieldable(d)
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 54e0b1a23b..bcd840bd88 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -21,6 +21,7 @@ from synapse.handlers.presence import PresenceEventSource
 from synapse.handlers.receipts import ReceiptEventSource
 from synapse.handlers.room import RoomEventSource
 from synapse.handlers.typing import TypingNotificationEventSource
+from synapse.logging.opentracing import trace
 from synapse.streams import EventSource
 from synapse.types import StreamToken
 
@@ -69,6 +70,7 @@ class EventSources:
         )
         return token
 
+    @trace
     async def get_current_token_for_pagination(self, room_id: str) -> StreamToken:
         """Get the current token for a given room to be used to paginate
         events.
diff --git a/synapse/types.py b/synapse/types.py
index 0586d2cbb9..668d48d646 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -267,7 +267,6 @@ class DomainSpecificString(metaclass=abc.ABCMeta):
             )
 
         domain = parts[1]
-
         # This code will need changing if we want to support multiple domain
         # names on one HS
         return cls(localpart=parts[0], domain=domain)
@@ -279,6 +278,8 @@ class DomainSpecificString(metaclass=abc.ABCMeta):
     @classmethod
     def is_valid(cls: Type[DS], s: str) -> bool:
         """Parses the input string and attempts to ensure it is valid."""
+        # TODO: this does not reject an empty localpart or an overly-long string.
+        # See https://spec.matrix.org/v1.2/appendices/#identifier-grammar
         try:
             obj = cls.from_string(s)
             # Apply additional validation to the domain. This is only done
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 6323d452e7..a90f08dd4c 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -20,6 +20,7 @@ from typing import Any, Callable, Dict, Generator, Optional
 import attr
 from frozendict import frozendict
 from matrix_common.versionstring import get_distribution_version_string
+from typing_extensions import ParamSpec
 
 from twisted.internet import defer, task
 from twisted.internet.defer import Deferred
@@ -82,6 +83,9 @@ def unwrapFirstError(failure: Failure) -> Failure:
     return failure.value.subFailure  # type: ignore[union-attr]  # Issue in Twisted's annotations
 
 
+P = ParamSpec("P")
+
+
 @attr.s(slots=True)
 class Clock:
     """
@@ -110,7 +114,7 @@ class Clock:
         return int(self.time() * 1000)
 
     def looping_call(
-        self, f: Callable, msec: float, *args: Any, **kwargs: Any
+        self, f: Callable[P, object], msec: float, *args: P.args, **kwargs: P.kwargs
     ) -> LoopingCall:
         """Call a function repeatedly.
 
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 42f6abb5e1..bdf9b0dc8c 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -34,10 +34,10 @@ TRACK_MEMORY_USAGE = False
 caches_by_name: Dict[str, Sized] = {}
 collectors_by_name: Dict[str, "CacheMetric"] = {}
 
-cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
-cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])
-cache_evicted = Gauge("synapse_util_caches_cache:evicted_size", "", ["name", "reason"])
-cache_total = Gauge("synapse_util_caches_cache:total", "", ["name"])
+cache_size = Gauge("synapse_util_caches_cache_size", "", ["name"])
+cache_hits = Gauge("synapse_util_caches_cache_hits", "", ["name"])
+cache_evicted = Gauge("synapse_util_caches_cache_evicted_size", "", ["name", "reason"])
+cache_total = Gauge("synapse_util_caches_cache_total", "", ["name"])
 cache_max_size = Gauge("synapse_util_caches_cache_max_size", "", ["name"])
 cache_memory_usage = Gauge(
     "synapse_util_caches_cache_size_bytes",
@@ -45,12 +45,12 @@ cache_memory_usage = Gauge(
     ["name"],
 )
 
-response_cache_size = Gauge("synapse_util_caches_response_cache:size", "", ["name"])
-response_cache_hits = Gauge("synapse_util_caches_response_cache:hits", "", ["name"])
+response_cache_size = Gauge("synapse_util_caches_response_cache_size", "", ["name"])
+response_cache_hits = Gauge("synapse_util_caches_response_cache_hits", "", ["name"])
 response_cache_evicted = Gauge(
-    "synapse_util_caches_response_cache:evicted_size", "", ["name", "reason"]
+    "synapse_util_caches_response_cache_evicted_size", "", ["name", "reason"]
 )
-response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"])
+response_cache_total = Gauge("synapse_util_caches_response_cache_total", "", ["name"])
 
 
 class EvictionReason(Enum):
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 1d6ec22191..6425f851ea 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -14,15 +14,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import abc
 import enum
 import threading
 from typing import (
     Callable,
+    Collection,
+    Dict,
     Generic,
-    Iterable,
     MutableMapping,
     Optional,
+    Set,
     Sized,
+    Tuple,
     TypeVar,
     Union,
     cast,
@@ -31,7 +35,6 @@ from typing import (
 from prometheus_client import Gauge
 
 from twisted.internet import defer
-from twisted.python import failure
 from twisted.python.failure import Failure
 
 from synapse.util.async_helpers import ObservableDeferred
@@ -94,7 +97,7 @@ class DeferredCache(Generic[KT, VT]):
 
         # _pending_deferred_cache maps from the key value to a `CacheEntry` object.
         self._pending_deferred_cache: Union[
-            TreeCache, "MutableMapping[KT, CacheEntry]"
+            TreeCache, "MutableMapping[KT, CacheEntry[KT, VT]]"
         ] = cache_type()
 
         def metrics_cb() -> None:
@@ -159,15 +162,16 @@ class DeferredCache(Generic[KT, VT]):
         Raises:
             KeyError if the key is not found in the cache
         """
-        callbacks = [callback] if callback else []
         val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
         if val is not _Sentinel.sentinel:
-            val.callbacks.update(callbacks)
+            val.add_invalidation_callback(key, callback)
             if update_metrics:
                 m = self.cache.metrics
                 assert m  # we always have a name, so should always have metrics
                 m.inc_hits()
-            return val.deferred.observe()
+            return val.deferred(key)
+
+        callbacks = (callback,) if callback else ()
 
         val2 = self.cache.get(
             key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics
@@ -177,6 +181,73 @@ class DeferredCache(Generic[KT, VT]):
         else:
             return defer.succeed(val2)
 
+    def get_bulk(
+        self,
+        keys: Collection[KT],
+        callback: Optional[Callable[[], None]] = None,
+    ) -> Tuple[Dict[KT, VT], Optional["defer.Deferred[Dict[KT, VT]]"], Collection[KT]]:
+        """Bulk lookup of items in the cache.
+
+        Returns:
+            A 3-tuple of:
+                1. a dict of key/value of items already cached;
+                2. a deferred that resolves to a dict of key/value of items
+                   we're already fetching; and
+                3. a collection of keys that don't appear in the previous two.
+        """
+
+        # The cached results
+        cached = {}
+
+        # List of pending deferreds
+        pending = []
+
+        # Dict that gets filled out when the pending deferreds complete
+        pending_results = {}
+
+        # List of keys that aren't in either cache
+        missing = []
+
+        callbacks = (callback,) if callback else ()
+
+        for key in keys:
+            # Check if its in the main cache.
+            immediate_value = self.cache.get(
+                key,
+                _Sentinel.sentinel,
+                callbacks=callbacks,
+            )
+            if immediate_value is not _Sentinel.sentinel:
+                cached[key] = immediate_value
+                continue
+
+            # Check if its in the pending cache
+            pending_value = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
+            if pending_value is not _Sentinel.sentinel:
+                pending_value.add_invalidation_callback(key, callback)
+
+                def completed_cb(value: VT, key: KT) -> VT:
+                    pending_results[key] = value
+                    return value
+
+                # Add a callback to fill out `pending_results` when that completes
+                d = pending_value.deferred(key).addCallback(completed_cb, key)
+                pending.append(d)
+                continue
+
+            # Not in either cache
+            missing.append(key)
+
+        # If we've got pending deferreds, squash them into a single one that
+        # returns `pending_results`.
+        pending_deferred = None
+        if pending:
+            pending_deferred = defer.gatherResults(
+                pending, consumeErrors=True
+            ).addCallback(lambda _: pending_results)
+
+        return (cached, pending_deferred, missing)
+
     def get_immediate(
         self, key: KT, default: T, update_metrics: bool = True
     ) -> Union[VT, T]:
@@ -218,84 +289,89 @@ class DeferredCache(Generic[KT, VT]):
             value: a deferred which will complete with a result to add to the cache
             callback: An optional callback to be called when the entry is invalidated
         """
-        if not isinstance(value, defer.Deferred):
-            raise TypeError("not a Deferred")
-
-        callbacks = [callback] if callback else []
         self.check_thread()
 
-        existing_entry = self._pending_deferred_cache.pop(key, None)
-        if existing_entry:
-            existing_entry.invalidate()
+        self._pending_deferred_cache.pop(key, None)
 
         # XXX: why don't we invalidate the entry in `self.cache` yet?
 
-        # we can save a whole load of effort if the deferred is ready.
-        if value.called:
-            result = value.result
-            if not isinstance(result, failure.Failure):
-                self.cache.set(key, cast(VT, result), callbacks)
-            return value
-
         # otherwise, we'll add an entry to the _pending_deferred_cache for now,
         # and add callbacks to add it to the cache properly later.
+        entry = CacheEntrySingle[KT, VT](value)
+        entry.add_invalidation_callback(key, callback)
+        self._pending_deferred_cache[key] = entry
+        deferred = entry.deferred(key).addCallbacks(
+            self._completed_callback,
+            self._error_callback,
+            callbackArgs=(entry, key),
+            errbackArgs=(entry, key),
+        )
 
-        observable = ObservableDeferred(value, consumeErrors=True)
-        observer = observable.observe()
-        entry = CacheEntry(deferred=observable, callbacks=callbacks)
+        # we return a new Deferred which will be called before any subsequent observers.
+        return deferred
 
-        self._pending_deferred_cache[key] = entry
+    def start_bulk_input(
+        self,
+        keys: Collection[KT],
+        callback: Optional[Callable[[], None]] = None,
+    ) -> "CacheMultipleEntries[KT, VT]":
+        """Bulk set API for use when fetching multiple keys at once from the DB.
 
-        def compare_and_pop() -> bool:
-            """Check if our entry is still the one in _pending_deferred_cache, and
-            if so, pop it.
-
-            Returns true if the entries matched.
-            """
-            existing_entry = self._pending_deferred_cache.pop(key, None)
-            if existing_entry is entry:
-                return True
-
-            # oops, the _pending_deferred_cache has been updated since
-            # we started our query, so we are out of date.
-            #
-            # Better put back whatever we took out. (We do it this way
-            # round, rather than peeking into the _pending_deferred_cache
-            # and then removing on a match, to make the common case faster)
-            if existing_entry is not None:
-                self._pending_deferred_cache[key] = existing_entry
-
-            return False
-
-        def cb(result: VT) -> None:
-            if compare_and_pop():
-                self.cache.set(key, result, entry.callbacks)
-            else:
-                # we're not going to put this entry into the cache, so need
-                # to make sure that the invalidation callbacks are called.
-                # That was probably done when _pending_deferred_cache was
-                # updated, but it's possible that `set` was called without
-                # `invalidate` being previously called, in which case it may
-                # not have been. Either way, let's double-check now.
-                entry.invalidate()
-
-        def eb(_fail: Failure) -> None:
-            compare_and_pop()
-            entry.invalidate()
-
-        # once the deferred completes, we can move the entry from the
-        # _pending_deferred_cache to the real cache.
-        #
-        observer.addCallbacks(cb, eb)
+        Called *before* starting the fetch from the DB, and the caller *must*
+        call either `complete_bulk(..)` or `error_bulk(..)` on the return value.
+        """
 
-        # we return a new Deferred which will be called before any subsequent observers.
-        return observable.observe()
+        entry = CacheMultipleEntries[KT, VT]()
+        entry.add_global_invalidation_callback(callback)
+
+        for key in keys:
+            self._pending_deferred_cache[key] = entry
+
+        return entry
+
+    def _completed_callback(
+        self, value: VT, entry: "CacheEntry[KT, VT]", key: KT
+    ) -> VT:
+        """Called when a deferred is completed."""
+        # We check if the current entry matches the entry associated with the
+        # deferred. If they don't match then it got invalidated.
+        current_entry = self._pending_deferred_cache.pop(key, None)
+        if current_entry is not entry:
+            if current_entry:
+                self._pending_deferred_cache[key] = current_entry
+            return value
+
+        self.cache.set(key, value, entry.get_invalidation_callbacks(key))
+
+        return value
+
+    def _error_callback(
+        self,
+        failure: Failure,
+        entry: "CacheEntry[KT, VT]",
+        key: KT,
+    ) -> Failure:
+        """Called when a deferred errors."""
+
+        # We check if the current entry matches the entry associated with the
+        # deferred. If they don't match then it got invalidated.
+        current_entry = self._pending_deferred_cache.pop(key, None)
+        if current_entry is not entry:
+            if current_entry:
+                self._pending_deferred_cache[key] = current_entry
+            return failure
+
+        for cb in entry.get_invalidation_callbacks(key):
+            cb()
+
+        return failure
 
     def prefill(
         self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None
     ) -> None:
-        callbacks = [callback] if callback else []
+        callbacks = (callback,) if callback else ()
         self.cache.set(key, value, callbacks=callbacks)
+        self._pending_deferred_cache.pop(key, None)
 
     def invalidate(self, key: KT) -> None:
         """Delete a key, or tree of entries
@@ -311,41 +387,129 @@ class DeferredCache(Generic[KT, VT]):
         self.cache.del_multi(key)
 
         # if we have a pending lookup for this key, remove it from the
-        # _pending_deferred_cache, which will (a) stop it being returned
-        # for future queries and (b) stop it being persisted as a proper entry
+        # _pending_deferred_cache, which will (a) stop it being returned for
+        # future queries and (b) stop it being persisted as a proper entry
         # in self.cache.
         entry = self._pending_deferred_cache.pop(key, None)
-
-        # run the invalidation callbacks now, rather than waiting for the
-        # deferred to resolve.
         if entry:
             # _pending_deferred_cache.pop should either return a CacheEntry, or, in the
             # case of a TreeCache, a dict of keys to cache entries. Either way calling
             # iterate_tree_cache_entry on it will do the right thing.
             for entry in iterate_tree_cache_entry(entry):
-                entry.invalidate()
+                for cb in entry.get_invalidation_callbacks(key):
+                    cb()
 
     def invalidate_all(self) -> None:
         self.check_thread()
         self.cache.clear()
-        for entry in self._pending_deferred_cache.values():
-            entry.invalidate()
+        for key, entry in self._pending_deferred_cache.items():
+            for cb in entry.get_invalidation_callbacks(key):
+                cb()
+
         self._pending_deferred_cache.clear()
 
 
-class CacheEntry:
-    __slots__ = ["deferred", "callbacks", "invalidated"]
+class CacheEntry(Generic[KT, VT], metaclass=abc.ABCMeta):
+    """Abstract class for entries in `DeferredCache[KT, VT]`"""
 
-    def __init__(
-        self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]]
-    ):
-        self.deferred = deferred
-        self.callbacks = set(callbacks)
-        self.invalidated = False
-
-    def invalidate(self) -> None:
-        if not self.invalidated:
-            self.invalidated = True
-            for callback in self.callbacks:
-                callback()
-            self.callbacks.clear()
+    @abc.abstractmethod
+    def deferred(self, key: KT) -> "defer.Deferred[VT]":
+        """Get a deferred that a caller can wait on to get the value at the
+        given key"""
+        ...
+
+    @abc.abstractmethod
+    def add_invalidation_callback(
+        self, key: KT, callback: Optional[Callable[[], None]]
+    ) -> None:
+        """Add an invalidation callback"""
+        ...
+
+    @abc.abstractmethod
+    def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
+        """Get all invalidation callbacks"""
+        ...
+
+
+class CacheEntrySingle(CacheEntry[KT, VT]):
+    """An implementation of `CacheEntry` wrapping a deferred that results in a
+    single cache entry.
+    """
+
+    __slots__ = ["_deferred", "_callbacks"]
+
+    def __init__(self, deferred: "defer.Deferred[VT]") -> None:
+        self._deferred = ObservableDeferred(deferred, consumeErrors=True)
+        self._callbacks: Set[Callable[[], None]] = set()
+
+    def deferred(self, key: KT) -> "defer.Deferred[VT]":
+        return self._deferred.observe()
+
+    def add_invalidation_callback(
+        self, key: KT, callback: Optional[Callable[[], None]]
+    ) -> None:
+        if callback is None:
+            return
+
+        self._callbacks.add(callback)
+
+    def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
+        return self._callbacks
+
+
+class CacheMultipleEntries(CacheEntry[KT, VT]):
+    """Cache entry that is used for bulk lookups and insertions."""
+
+    __slots__ = ["_deferred", "_callbacks", "_global_callbacks"]
+
+    def __init__(self) -> None:
+        self._deferred: Optional[ObservableDeferred[Dict[KT, VT]]] = None
+        self._callbacks: Dict[KT, Set[Callable[[], None]]] = {}
+        self._global_callbacks: Set[Callable[[], None]] = set()
+
+    def deferred(self, key: KT) -> "defer.Deferred[VT]":
+        if not self._deferred:
+            self._deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
+        return self._deferred.observe().addCallback(lambda res: res.get(key))
+
+    def add_invalidation_callback(
+        self, key: KT, callback: Optional[Callable[[], None]]
+    ) -> None:
+        if callback is None:
+            return
+
+        self._callbacks.setdefault(key, set()).add(callback)
+
+    def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
+        return self._callbacks.get(key, set()) | self._global_callbacks
+
+    def add_global_invalidation_callback(
+        self, callback: Optional[Callable[[], None]]
+    ) -> None:
+        """Add a callback for when any keys get invalidated."""
+        if callback is None:
+            return
+
+        self._global_callbacks.add(callback)
+
+    def complete_bulk(
+        self,
+        cache: DeferredCache[KT, VT],
+        result: Dict[KT, VT],
+    ) -> None:
+        """Called when there is a result"""
+        for key, value in result.items():
+            cache._completed_callback(value, self, key)
+
+        if self._deferred:
+            self._deferred.callback(result)
+
+    def error_bulk(
+        self, cache: DeferredCache[KT, VT], keys: Collection[KT], failure: Failure
+    ) -> None:
+        """Called when bulk lookup failed."""
+        for key in keys:
+            cache._error_callback(failure, self, key)
+
+        if self._deferred:
+            self._deferred.errback(failure)
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 867f315b2a..10aff4d04a 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -25,6 +25,7 @@ from typing import (
     Generic,
     Hashable,
     Iterable,
+    List,
     Mapping,
     Optional,
     Sequence,
@@ -73,8 +74,10 @@ class _CacheDescriptorBase:
         num_args: Optional[int],
         uncached_args: Optional[Collection[str]] = None,
         cache_context: bool = False,
+        name: Optional[str] = None,
     ):
         self.orig = orig
+        self.name = name or orig.__name__
 
         arg_spec = inspect.getfullargspec(orig)
         all_args = arg_spec.args
@@ -211,7 +214,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
 
     def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
         cache: LruCache[CacheKey, Any] = LruCache(
-            cache_name=self.orig.__name__,
+            cache_name=self.name,
             max_size=self.max_entries,
         )
 
@@ -241,7 +244,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
 
         wrapped = cast(_CachedFunction, _wrapped)
         wrapped.cache = cache
-        obj.__dict__[self.orig.__name__] = wrapped
+        obj.__dict__[self.name] = wrapped
 
         return wrapped
 
@@ -301,12 +304,14 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         cache_context: bool = False,
         iterable: bool = False,
         prune_unread_entries: bool = True,
+        name: Optional[str] = None,
     ):
         super().__init__(
             orig,
             num_args=num_args,
             uncached_args=uncached_args,
             cache_context=cache_context,
+            name=name,
         )
 
         if tree and self.num_args < 2:
@@ -321,7 +326,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
 
     def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
         cache: DeferredCache[CacheKey, Any] = DeferredCache(
-            name=self.orig.__name__,
+            name=self.name,
             max_entries=self.max_entries,
             tree=self.tree,
             iterable=self.iterable,
@@ -372,7 +377,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         wrapped.cache = cache
         wrapped.num_args = self.num_args
 
-        obj.__dict__[self.orig.__name__] = wrapped
+        obj.__dict__[self.name] = wrapped
 
         return wrapped
 
@@ -393,6 +398,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
         cached_method_name: str,
         list_name: str,
         num_args: Optional[int] = None,
+        name: Optional[str] = None,
     ):
         """
         Args:
@@ -403,7 +409,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                 but including list_name) to use as cache keys. Defaults to all
                 named args of the function.
         """
-        super().__init__(orig, num_args=num_args, uncached_args=None)
+        super().__init__(orig, num_args=num_args, uncached_args=None, name=name)
 
         self.list_name = list_name
 
@@ -435,16 +441,6 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
             keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
             list_args = arg_dict[self.list_name]
 
-            results = {}
-
-            def update_results_dict(res: Any, arg: Hashable) -> None:
-                results[arg] = res
-
-            # list of deferreds to wait for
-            cached_defers = []
-
-            missing = set()
-
             # If the cache takes a single arg then that is used as the key,
             # otherwise a tuple is used.
             if num_args == 1:
@@ -452,6 +448,9 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                 def arg_to_cache_key(arg: Hashable) -> Hashable:
                     return arg
 
+                def cache_key_to_arg(key: tuple) -> Hashable:
+                    return key
+
             else:
                 keylist = list(keyargs)
 
@@ -459,58 +458,53 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                     keylist[self.list_pos] = arg
                     return tuple(keylist)
 
-            for arg in list_args:
-                try:
-                    res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
-                    if not res.called:
-                        res.addCallback(update_results_dict, arg)
-                        cached_defers.append(res)
-                    else:
-                        results[arg] = res.result
-                except KeyError:
-                    missing.add(arg)
+                def cache_key_to_arg(key: tuple) -> Hashable:
+                    return key[self.list_pos]
+
+            cache_keys = [arg_to_cache_key(arg) for arg in list_args]
+            immediate_results, pending_deferred, missing = cache.get_bulk(
+                cache_keys, callback=invalidate_callback
+            )
+
+            results = {cache_key_to_arg(key): v for key, v in immediate_results.items()}
+
+            cached_defers: List["defer.Deferred[Any]"] = []
+            if pending_deferred:
+
+                def update_results(r: Dict) -> None:
+                    for k, v in r.items():
+                        results[cache_key_to_arg(k)] = v
+
+                pending_deferred.addCallback(update_results)
+                cached_defers.append(pending_deferred)
 
             if missing:
-                # we need a deferred for each entry in the list,
-                # which we put in the cache. Each deferred resolves with the
-                # relevant result for that key.
-                deferreds_map = {}
-                for arg in missing:
-                    deferred: "defer.Deferred[Any]" = defer.Deferred()
-                    deferreds_map[arg] = deferred
-                    key = arg_to_cache_key(arg)
-                    cached_defers.append(
-                        cache.set(key, deferred, callback=invalidate_callback)
-                    )
+                cache_entry = cache.start_bulk_input(missing, invalidate_callback)
 
                 def complete_all(res: Dict[Hashable, Any]) -> None:
-                    # the wrapped function has completed. It returns a dict.
-                    # We can now update our own result map, and then resolve the
-                    # observable deferreds in the cache.
-                    for e, d1 in deferreds_map.items():
-                        val = res.get(e, None)
-                        # make sure we update the results map before running the
-                        # deferreds, because as soon as we run the last deferred, the
-                        # gatherResults() below will complete and return the result
-                        # dict to our caller.
-                        results[e] = val
-                        d1.callback(val)
+                    missing_results = {}
+                    for key in missing:
+                        arg = cache_key_to_arg(key)
+                        val = res.get(arg, None)
+
+                        results[arg] = val
+                        missing_results[key] = val
+
+                    cache_entry.complete_bulk(cache, missing_results)
 
                 def errback_all(f: Failure) -> None:
-                    # the wrapped function has failed. Propagate the failure into
-                    # the cache, which will invalidate the entry, and cause the
-                    # relevant cached_deferreds to fail, which will propagate the
-                    # failure to our caller.
-                    for d1 in deferreds_map.values():
-                        d1.errback(f)
+                    cache_entry.error_bulk(cache, missing, f)
 
                 args_to_call = dict(arg_dict)
-                args_to_call[self.list_name] = missing
+                args_to_call[self.list_name] = {
+                    cache_key_to_arg(key) for key in missing
+                }
 
                 # dispatch the call, and attach the two handlers
-                defer.maybeDeferred(
+                missing_d = defer.maybeDeferred(
                     preserve_fn(self.orig), **args_to_call
                 ).addCallbacks(complete_all, errback_all)
+                cached_defers.append(missing_d)
 
             if cached_defers:
                 d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
@@ -525,7 +519,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
             else:
                 return defer.succeed(results)
 
-        obj.__dict__[self.orig.__name__] = wrapped
+        obj.__dict__[self.name] = wrapped
 
         return wrapped
 
@@ -577,6 +571,7 @@ def cached(
     cache_context: bool = False,
     iterable: bool = False,
     prune_unread_entries: bool = True,
+    name: Optional[str] = None,
 ) -> Callable[[F], _CachedFunction[F]]:
     func = lambda orig: DeferredCacheDescriptor(
         orig,
@@ -587,13 +582,18 @@ def cached(
         cache_context=cache_context,
         iterable=iterable,
         prune_unread_entries=prune_unread_entries,
+        name=name,
     )
 
     return cast(Callable[[F], _CachedFunction[F]], func)
 
 
 def cachedList(
-    *, cached_method_name: str, list_name: str, num_args: Optional[int] = None
+    *,
+    cached_method_name: str,
+    list_name: str,
+    num_args: Optional[int] = None,
+    name: Optional[str] = None,
 ) -> Callable[[F], _CachedFunction[F]]:
     """Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.
 
@@ -628,6 +628,7 @@ def cachedList(
         cached_method_name=cached_method_name,
         list_name=list_name,
         num_args=num_args,
+        name=name,
     )
 
     return cast(Callable[[F], _CachedFunction[F]], func)
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index d267703df0..fa91479c97 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -14,11 +14,13 @@
 import enum
 import logging
 import threading
-from typing import Any, Dict, Generic, Iterable, Optional, Set, TypeVar
+from typing import Any, Dict, Generic, Iterable, Optional, Set, Tuple, TypeVar, Union
 
 import attr
+from typing_extensions import Literal
 
 from synapse.util.caches.lrucache import LruCache
+from synapse.util.caches.treecache import TreeCache
 
 logger = logging.getLogger(__name__)
 
@@ -33,10 +35,12 @@ DV = TypeVar("DV")
 
 # This class can't be generic because it uses slots with attrs.
 # See: https://github.com/python-attrs/attrs/issues/313
-@attr.s(slots=True, auto_attribs=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class DictionaryEntry:  # should be: Generic[DKT, DV].
     """Returned when getting an entry from the cache
 
+    If `full` is true then `known_absent` will be the empty set.
+
     Attributes:
         full: Whether the cache has the full or dict or just some keys.
             If not full then not all requested keys will necessarily be present
@@ -53,20 +57,90 @@ class DictionaryEntry:  # should be: Generic[DKT, DV].
         return len(self.value)
 
 
+class _FullCacheKey(enum.Enum):
+    """The key we use to cache the full dict."""
+
+    KEY = object()
+
+
 class _Sentinel(enum.Enum):
     # defining a sentinel in this way allows mypy to correctly handle the
     # type of a dictionary lookup.
     sentinel = object()
 
 
+class _PerKeyValue(Generic[DV]):
+    """The cached value of a dictionary key. If `value` is the sentinel,
+    indicates that the requested key is known to *not* be in the full dict.
+    """
+
+    __slots__ = ["value"]
+
+    def __init__(self, value: Union[DV, Literal[_Sentinel.sentinel]]) -> None:
+        self.value = value
+
+    def __len__(self) -> int:
+        # We add a `__len__` implementation as we use this class in a cache
+        # where the values are variable length.
+        return 1
+
+
 class DictionaryCache(Generic[KT, DKT, DV]):
     """Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
     fetching a subset of dictionary keys for a particular key.
+
+    This cache has two levels of key. First there is the "cache key" (of type
+    `KT`), which maps to a dict. The keys to that dict are the "dict key" (of
+    type `DKT`). The overall structure is therefore `KT->DKT->DV`. For
+    example, it might look like:
+
+       {
+           1: { 1: "a", 2: "b" },
+           2: { 1: "c" },
+       }
+
+    It is possible to look up either individual dict keys, or the *complete*
+    dict for a given cache key.
+
+    Each dict item, and the complete dict is treated as a separate LRU
+    entry for the purpose of cache expiry. For example, given:
+        dict_cache.get(1, None)  -> DictionaryEntry({1: "a", 2: "b"})
+        dict_cache.get(1, [1])  -> DictionaryEntry({1: "a"})
+        dict_cache.get(1, [2])  -> DictionaryEntry({2: "b"})
+
+    ... then the cache entry for the complete dict will expire first,
+    followed by the cache entry for the '1' dict key, and finally that
+    for the '2' dict key.
     """
 
     def __init__(self, name: str, max_entries: int = 1000):
-        self.cache: LruCache[KT, DictionaryEntry] = LruCache(
-            max_size=max_entries, cache_name=name, size_callback=len
+        # We use a single LruCache to store two different types of entries:
+        #   1. Map from (key, dict_key) -> dict value (or sentinel, indicating
+        #      the key doesn't exist in the dict); and
+        #   2. Map from (key, _FullCacheKey.KEY) -> full dict.
+        #
+        # The former is used when explicit keys of the dictionary are looked up,
+        # and the latter when the full dictionary is requested.
+        #
+        # If when explicit keys are requested and not in the cache, we then look
+        # to see if we have the full dict and use that if we do. If found in the
+        # full dict each key is added into the cache.
+        #
+        # This set up allows the `LruCache` to prune the full dict entries if
+        # they haven't been used in a while, even when there have been recent
+        # queries for subsets of the dict.
+        #
+        # Typing:
+        #     * A key of `(KT, DKT)` has a value of `_PerKeyValue`
+        #     * A key of `(KT, _FullCacheKey.KEY)` has a value of `Dict[DKT, DV]`
+        self.cache: LruCache[
+            Tuple[KT, Union[DKT, Literal[_FullCacheKey.KEY]]],
+            Union[_PerKeyValue, Dict[DKT, DV]],
+        ] = LruCache(
+            max_size=max_entries,
+            cache_name=name,
+            cache_type=TreeCache,
+            size_callback=len,
         )
 
         self.name = name
@@ -91,23 +165,83 @@ class DictionaryCache(Generic[KT, DKT, DV]):
         Args:
             key
             dict_keys: If given a set of keys then return only those keys
-                that exist in the cache.
+                that exist in the cache. If None then returns the full dict
+                if it is in the cache.
 
         Returns:
-            DictionaryEntry
+            DictionaryEntry: If `dict_keys` is not None then `DictionaryEntry`
+            will contain include the keys that are in the cache. If None then
+            will either return the full dict if in the cache, or the empty
+            dict (with `full` set to False) if it isn't.
         """
-        entry = self.cache.get(key, _Sentinel.sentinel)
-        if entry is not _Sentinel.sentinel:
-            if dict_keys is None:
-                return DictionaryEntry(
-                    entry.full, entry.known_absent, dict(entry.value)
-                )
+        if dict_keys is None:
+            # The caller wants the full set of dictionary keys for this cache key
+            return self._get_full_dict(key)
+
+        # We are being asked for a subset of keys.
+
+        # First go and check for each requested dict key in the cache, tracking
+        # which we couldn't find.
+        values = {}
+        known_absent = set()
+        missing = []
+        for dict_key in dict_keys:
+            entry = self.cache.get((key, dict_key), _Sentinel.sentinel)
+            if entry is _Sentinel.sentinel:
+                missing.append(dict_key)
+                continue
+
+            assert isinstance(entry, _PerKeyValue)
+
+            if entry.value is _Sentinel.sentinel:
+                known_absent.add(dict_key)
             else:
-                return DictionaryEntry(
-                    entry.full,
-                    entry.known_absent,
-                    {k: entry.value[k] for k in dict_keys if k in entry.value},
-                )
+                values[dict_key] = entry.value
+
+        # If we found everything we can return immediately.
+        if not missing:
+            return DictionaryEntry(False, known_absent, values)
+
+        # We are missing some keys, so check if we happen to have the full dict in
+        # the cache.
+        #
+        # We don't update the last access time for this cache fetch, as we
+        # aren't explicitly interested in the full dict and so we don't want
+        # requests for explicit dict keys to keep the full dict in the cache.
+        entry = self.cache.get(
+            (key, _FullCacheKey.KEY),
+            _Sentinel.sentinel,
+            update_last_access=False,
+        )
+        if entry is _Sentinel.sentinel:
+            # Not in the cache, return the subset of keys we found.
+            return DictionaryEntry(False, known_absent, values)
+
+        # We have the full dict!
+        assert isinstance(entry, dict)
+
+        for dict_key in missing:
+            # We explicitly add each dict key to the cache, so that cache hit
+            # rates and LRU times for each key can be tracked separately.
+            value = entry.get(dict_key, _Sentinel.sentinel)  # type: ignore[arg-type]
+            self.cache[(key, dict_key)] = _PerKeyValue(value)
+
+            if value is not _Sentinel.sentinel:
+                values[dict_key] = value
+
+        return DictionaryEntry(True, set(), values)
+
+    def _get_full_dict(
+        self,
+        key: KT,
+    ) -> DictionaryEntry:
+        """Fetch the full dict for the given key."""
+
+        # First we check if we have cached the full dict.
+        entry = self.cache.get((key, _FullCacheKey.KEY), _Sentinel.sentinel)
+        if entry is not _Sentinel.sentinel:
+            assert isinstance(entry, dict)
+            return DictionaryEntry(True, set(), entry)
 
         return DictionaryEntry(False, set(), {})
 
@@ -117,7 +251,13 @@ class DictionaryCache(Generic[KT, DKT, DV]):
         # Increment the sequence number so that any SELECT statements that
         # raced with the INSERT don't update the cache (SYN-369)
         self.sequence += 1
-        self.cache.pop(key, None)
+
+        # We want to drop all information about the dict for the given key, so
+        # we use `del_multi` to delete it all in one go.
+        #
+        # We ignore the type error here: `del_multi` accepts a truncated key
+        # (when the key type is a tuple).
+        self.cache.del_multi((key,))  # type: ignore[arg-type]
 
     def invalidate_all(self) -> None:
         self.check_thread()
@@ -131,7 +271,16 @@ class DictionaryCache(Generic[KT, DKT, DV]):
         value: Dict[DKT, DV],
         fetched_keys: Optional[Iterable[DKT]] = None,
     ) -> None:
-        """Updates the entry in the cache
+        """Updates the entry in the cache.
+
+        Note: This does *not* invalidate any existing entries for the `key`.
+        In particular, if we add an entry for the cached "full dict" with
+        `fetched_keys=None`, existing entries for individual dict keys are
+        not invalidated. Likewise, adding entries for individual keys does
+        not invalidate any cached value for the full dict.
+
+        In other words: if the underlying data is *changed*, the cache must
+        be explicitly invalidated via `.invalidate()`.
 
         Args:
             sequence
@@ -149,20 +298,27 @@ class DictionaryCache(Generic[KT, DKT, DV]):
             # Only update the cache if the caches sequence number matches the
             # number that the cache had before the SELECT was started (SYN-369)
             if fetched_keys is None:
-                self._insert(key, value, set())
+                self.cache[(key, _FullCacheKey.KEY)] = value
             else:
-                self._update_or_insert(key, value, fetched_keys)
+                self._update_subset(key, value, fetched_keys)
 
-    def _update_or_insert(
-        self, key: KT, value: Dict[DKT, DV], known_absent: Iterable[DKT]
+    def _update_subset(
+        self, key: KT, value: Dict[DKT, DV], fetched_keys: Iterable[DKT]
     ) -> None:
-        # We pop and reinsert as we need to tell the cache the size may have
-        # changed
+        """Add the given dictionary values as explicit keys in the cache.
+
+        Args:
+            key: top-level cache key
+            value: The dictionary with all the values that we should cache
+            fetched_keys: The full set of dict keys that were looked up. Any keys
+                here not in `value` should be marked as "known absent".
+        """
+
+        for dict_key, dict_value in value.items():
+            self.cache[(key, dict_key)] = _PerKeyValue(dict_value)
 
-        entry: DictionaryEntry = self.cache.pop(key, DictionaryEntry(False, set(), {}))
-        entry.value.update(value)
-        entry.known_absent.update(known_absent)
-        self.cache[key] = entry
+        for dict_key in fetched_keys:
+            if dict_key in value:
+                continue
 
-    def _insert(self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT]) -> None:
-        self.cache[key] = DictionaryEntry(True, known_absent, value)
+            self.cache[(key, dict_key)] = _PerKeyValue(_Sentinel.sentinel)
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index a3b60578e3..aa93109d13 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -25,8 +25,10 @@ from typing import (
     Collection,
     Dict,
     Generic,
+    Iterable,
     List,
     Optional,
+    Tuple,
     Type,
     TypeVar,
     Union,
@@ -44,7 +46,11 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
 from synapse.metrics.jemalloc import get_jemalloc_stats
 from synapse.util import Clock, caches
 from synapse.util.caches import CacheMetric, EvictionReason, register_cache
-from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
+from synapse.util.caches.treecache import (
+    TreeCache,
+    iterate_tree_cache_entry,
+    iterate_tree_cache_items,
+)
 from synapse.util.linked_list import ListNode
 
 if TYPE_CHECKING:
@@ -109,7 +115,7 @@ GLOBAL_ROOT = ListNode["_Node"].create_root_node()
 
 @wrap_as_background_process("LruCache._expire_old_entries")
 async def _expire_old_entries(
-    clock: Clock, expiry_seconds: int, autotune_config: Optional[dict]
+    clock: Clock, expiry_seconds: float, autotune_config: Optional[dict]
 ) -> None:
     """Walks the global cache list to find cache entries that haven't been
     accessed in the given number of seconds, or if a given memory threshold has been breached.
@@ -537,6 +543,7 @@ class LruCache(Generic[KT, VT]):
             default: Literal[None] = None,
             callbacks: Collection[Callable[[], None]] = ...,
             update_metrics: bool = ...,
+            update_last_access: bool = ...,
         ) -> Optional[VT]:
             ...
 
@@ -546,6 +553,7 @@ class LruCache(Generic[KT, VT]):
             default: T,
             callbacks: Collection[Callable[[], None]] = ...,
             update_metrics: bool = ...,
+            update_last_access: bool = ...,
         ) -> Union[T, VT]:
             ...
 
@@ -555,10 +563,27 @@ class LruCache(Generic[KT, VT]):
             default: Optional[T] = None,
             callbacks: Collection[Callable[[], None]] = (),
             update_metrics: bool = True,
+            update_last_access: bool = True,
         ) -> Union[None, T, VT]:
+            """Look up a key in the cache
+
+            Args:
+                key
+                default
+                callbacks: A collection of callbacks that will fire when the
+                    node is removed from the cache (either due to invalidation
+                    or expiry).
+                update_metrics: Whether to update the hit rate metrics
+                update_last_access: Whether to update the last access metrics
+                    on a node if successfully fetched. These metrics are used
+                    to determine when to remove the node from the cache. Set
+                    to False if this fetch should *not* prevent a node from
+                    being expired.
+            """
             node = cache.get(key, None)
             if node is not None:
-                move_node_to_front(node)
+                if update_last_access:
+                    move_node_to_front(node)
                 node.add_callbacks(callbacks)
                 if update_metrics and metrics:
                     metrics.inc_hits()
@@ -568,6 +593,65 @@ class LruCache(Generic[KT, VT]):
                     metrics.inc_misses()
                 return default
 
+        @overload
+        def cache_get_multi(
+            key: tuple,
+            default: Literal[None] = None,
+            update_metrics: bool = True,
+        ) -> Union[None, Iterable[Tuple[KT, VT]]]:
+            ...
+
+        @overload
+        def cache_get_multi(
+            key: tuple,
+            default: T,
+            update_metrics: bool = True,
+        ) -> Union[T, Iterable[Tuple[KT, VT]]]:
+            ...
+
+        @synchronized
+        def cache_get_multi(
+            key: tuple,
+            default: Optional[T] = None,
+            update_metrics: bool = True,
+        ) -> Union[None, T, Iterable[Tuple[KT, VT]]]:
+            """Returns a generator yielding all entries under the given key.
+
+            Can only be used if backed by a tree cache.
+
+            Example:
+
+                cache = LruCache(10, cache_type=TreeCache)
+                cache[(1, 1)] = "a"
+                cache[(1, 2)] = "b"
+                cache[(2, 1)] = "c"
+
+                items = cache.get_multi((1,))
+                assert list(items) == [((1, 1), "a"), ((1, 2), "b")]
+
+            Returns:
+                Either default if the key doesn't exist, or a generator of the
+                key/value pairs.
+            """
+
+            assert isinstance(cache, TreeCache)
+
+            node = cache.get(key, None)
+            if node is not None:
+                if update_metrics and metrics:
+                    metrics.inc_hits()
+
+                # We store entries in the `TreeCache` with values of type `_Node`,
+                # which we need to unwrap.
+                return (
+                    (full_key, lru_node.value)
+                    for full_key, lru_node in iterate_tree_cache_items(key, node)
+                )
+            else:
+                if update_metrics and metrics:
+                    metrics.inc_misses()
+                return default
+
         @synchronized
         def cache_set(
             key: KT, value: VT, callbacks: Collection[Callable[[], None]] = ()
@@ -674,6 +758,8 @@ class LruCache(Generic[KT, VT]):
         self.setdefault = cache_set_default
         self.pop = cache_pop
         self.del_multi = cache_del_multi
+        if cache_type is TreeCache:
+            self.get_multi = cache_get_multi
         # `invalidate` is exposed for consistency with DeferredCache, so that it can be
         # invalidated by the cache invalidation replication stream.
         self.invalidate = cache_del_multi
@@ -730,3 +816,58 @@ class LruCache(Generic[KT, VT]):
         # This happens e.g. in the sync code where we have an expiring cache of
         # lru caches.
         self.clear()
+
+
+class AsyncLruCache(Generic[KT, VT]):
+    """
+    An asynchronous wrapper around a subset of the LruCache API.
+
+    On its own this doesn't change the behaviour but allows subclasses that
+    utilize external cache systems that require await behaviour to be created.
+    """
+
+    def __init__(self, *args, **kwargs):  # type: ignore
+        self._lru_cache: LruCache[KT, VT] = LruCache(*args, **kwargs)
+
+    async def get(
+        self, key: KT, default: Optional[T] = None, update_metrics: bool = True
+    ) -> Optional[VT]:
+        return self._lru_cache.get(key, update_metrics=update_metrics)
+
+    async def get_external(
+        self,
+        key: KT,
+        default: Optional[T] = None,
+        update_metrics: bool = True,
+    ) -> Optional[VT]:
+        # This method should fetch from any configured external cache, in this case noop.
+        return None
+
+    def get_local(
+        self, key: KT, default: Optional[T] = None, update_metrics: bool = True
+    ) -> Optional[VT]:
+        return self._lru_cache.get(key, update_metrics=update_metrics)
+
+    async def set(self, key: KT, value: VT) -> None:
+        self._lru_cache.set(key, value)
+
+    def set_local(self, key: KT, value: VT) -> None:
+        self._lru_cache.set(key, value)
+
+    async def invalidate(self, key: KT) -> None:
+        # This method should invalidate any external cache and then invalidate the LruCache.
+        return self._lru_cache.invalidate(key)
+
+    def invalidate_local(self, key: KT) -> None:
+        """Remove an entry from the local cache
+
+        This variant of `invalidate` is useful if we know that the external
+        cache has already been invalidated.
+        """
+        return self._lru_cache.invalidate(key)
+
+    async def contains(self, key: KT) -> bool:
+        return self._lru_cache.contains(key)
+
+    async def clear(self) -> None:
+        self._lru_cache.clear()
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index e78305f787..fec31da2b6 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -64,6 +64,15 @@ class TreeCache:
         self.size += 1
 
     def get(self, key, default=None):
+        """When `key` is a full key, fetches the value for the given key (if
+        any).
+
+        If `key` is only a partial key (i.e. a truncated tuple) then returns a
+        `TreeCacheNode`, which can be passed to the `iterate_tree_cache_*`
+        functions to iterate over all entries in the cache with keys that start
+        with the given partial key.
+        """
+
         node = self.root
         for k in key[:-1]:
             node = node.get(k, None)
@@ -126,6 +135,9 @@ class TreeCache:
     def values(self):
         return iterate_tree_cache_entry(self.root)
 
+    def items(self):
+        return iterate_tree_cache_items((), self.root)
+
     def __len__(self) -> int:
         return self.size
 
@@ -139,3 +151,32 @@ def iterate_tree_cache_entry(d):
             yield from iterate_tree_cache_entry(value_d)
     else:
         yield d
+
+
+def iterate_tree_cache_items(key, value):
+    """Helper function to iterate over the leaves of a tree, i.e. a dict of that
+    can contain dicts.
+
+    The provided key is a tuple that will get prepended to the returned keys.
+
+    Example:
+
+        cache = TreeCache()
+        cache[(1, 1)] = "a"
+        cache[(1, 2)] = "b"
+        cache[(2, 1)] = "c"
+
+        tree_node = cache.get((1,))
+
+        items = iterate_tree_cache_items((1,), tree_node)
+        assert list(items) == [((1, 1), "a"), ((1, 2), "b")]
+
+    Returns:
+        A generator yielding key/value pairs.
+    """
+    if isinstance(value, TreeCacheNode):
+        for sub_key, sub_value in value.items():
+            yield from iterate_tree_cache_items((*key, sub_key), sub_value)
+    else:
+        # we've reached a leaf of the tree.
+        yield key, value
diff --git a/synapse/util/cancellation.py b/synapse/util/cancellation.py
new file mode 100644
index 0000000000..472d2e3aeb
--- /dev/null
+++ b/synapse/util/cancellation.py
@@ -0,0 +1,56 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Callable, TypeVar
+
+F = TypeVar("F", bound=Callable[..., Any])
+
+
+def cancellable(function: F) -> F:
+    """Marks a function as cancellable.
+
+    Servlet methods with this decorator will be cancelled if the client disconnects before we
+    finish processing the request.
+
+    Although this annotation is particularly useful for servlet methods, it's also
+    useful for intermediate functions, where it documents the fact that the function has
+    been audited for cancellation safety and needs to preserve that.
+    This then simplifies auditing new functions that call those same intermediate
+    functions.
+
+    During cancellation, `Deferred.cancel()` will be invoked on the `Deferred` wrapping
+    the method. The `cancel()` call will propagate down to the `Deferred` that is
+    currently being waited on. That `Deferred` will raise a `CancelledError`, which will
+    propagate up, as per normal exception handling.
+
+    Before applying this decorator to a new function, you MUST recursively check
+    that all `await`s in the function are on `async` functions or `Deferred`s that
+    handle cancellation cleanly, otherwise a variety of bugs may occur, ranging from
+    premature logging context closure, to stuck requests, to database corruption.
+
+    See the documentation page on Cancellation for more information.
+
+    Usage:
+        class SomeServlet(RestServlet):
+            @cancellable
+            async def on_GET(self, request: SynapseRequest) -> ...:
+                ...
+    """
+
+    function.cancellable = True  # type: ignore[attr-defined]
+    return function
+
+
+def is_function_cancellable(function: Callable[..., Any]) -> bool:
+    """Checks whether a servlet method has the `@cancellable` flag."""
+    return getattr(function, "cancellable", False)
diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py
index 84e4f6ff55..df77edcce2 100644
--- a/synapse/util/macaroons.py
+++ b/synapse/util/macaroons.py
@@ -17,8 +17,14 @@
 
 from typing import Callable, Optional
 
+import attr
 import pymacaroons
 from pymacaroons.exceptions import MacaroonVerificationFailedException
+from typing_extensions import Literal
+
+from synapse.util import Clock, stringutils
+
+MacaroonType = Literal["access", "delete_pusher", "session", "login"]
 
 
 def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
@@ -86,3 +92,305 @@ def satisfy_expiry(v: pymacaroons.Verifier, get_time_ms: Callable[[], int]) -> N
         return time_msec < expiry
 
     v.satisfy_general(verify_expiry_caveat)
+
+
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class OidcSessionData:
+    """The attributes which are stored in a OIDC session cookie"""
+
+    idp_id: str
+    """The Identity Provider being used"""
+
+    nonce: str
+    """The `nonce` parameter passed to the OIDC provider."""
+
+    client_redirect_url: str
+    """The URL the client gave when it initiated the flow. ("" if this is a UI Auth)"""
+
+    ui_auth_session_id: str
+    """The session ID of the ongoing UI Auth ("" if this is a login)"""
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class LoginTokenAttributes:
+    """Data we store in a short-term login token"""
+
+    user_id: str
+
+    auth_provider_id: str
+    """The SSO Identity Provider that the user authenticated with, to get this token."""
+
+    auth_provider_session_id: Optional[str]
+    """The session ID advertised by the SSO Identity Provider."""
+
+
+class MacaroonGenerator:
+    def __init__(self, clock: Clock, location: str, secret_key: bytes):
+        self._clock = clock
+        self._location = location
+        self._secret_key = secret_key
+
+    def generate_guest_access_token(self, user_id: str) -> str:
+        """Generate a guest access token for the given user ID
+
+        Args:
+            user_id: The user ID for which the guest token should be generated.
+
+        Returns:
+            A signed access token for that guest user.
+        """
+        nonce = stringutils.random_string_with_symbols(16)
+        macaroon = self._generate_base_macaroon("access")
+        macaroon.add_first_party_caveat(f"user_id = {user_id}")
+        macaroon.add_first_party_caveat(f"nonce = {nonce}")
+        macaroon.add_first_party_caveat("guest = true")
+        return macaroon.serialize()
+
+    def generate_delete_pusher_token(
+        self, user_id: str, app_id: str, pushkey: str
+    ) -> str:
+        """Generate a signed token used for unsubscribing from email notifications
+
+        Args:
+            user_id: The user for which this token will be valid.
+            app_id: The app_id for this pusher.
+            pushkey: The unique identifier of this pusher.
+
+        Returns:
+            A signed token which can be used in unsubscribe links.
+        """
+        macaroon = self._generate_base_macaroon("delete_pusher")
+        macaroon.add_first_party_caveat(f"user_id = {user_id}")
+        macaroon.add_first_party_caveat(f"app_id = {app_id}")
+        macaroon.add_first_party_caveat(f"pushkey = {pushkey}")
+        return macaroon.serialize()
+
+    def generate_short_term_login_token(
+        self,
+        user_id: str,
+        auth_provider_id: str,
+        auth_provider_session_id: Optional[str] = None,
+        duration_in_ms: int = (2 * 60 * 1000),
+    ) -> str:
+        """Generate a short-term login token used during SSO logins
+
+        Args:
+            user_id: The user for which the token is valid.
+            auth_provider_id: The SSO IdP the user used.
+            auth_provider_session_id: The session ID got during login from the SSO IdP.
+
+        Returns:
+            A signed token valid for using as a ``m.login.token`` token.
+        """
+        now = self._clock.time_msec()
+        expiry = now + duration_in_ms
+        macaroon = self._generate_base_macaroon("login")
+        macaroon.add_first_party_caveat(f"user_id = {user_id}")
+        macaroon.add_first_party_caveat(f"time < {expiry}")
+        macaroon.add_first_party_caveat(f"auth_provider_id = {auth_provider_id}")
+        if auth_provider_session_id is not None:
+            macaroon.add_first_party_caveat(
+                f"auth_provider_session_id = {auth_provider_session_id}"
+            )
+        return macaroon.serialize()
+
+    def generate_oidc_session_token(
+        self,
+        state: str,
+        session_data: OidcSessionData,
+        duration_in_ms: int = (60 * 60 * 1000),
+    ) -> str:
+        """Generates a signed token storing data about an OIDC session.
+
+        When Synapse initiates an authorization flow, it creates a random state
+        and a random nonce. Those parameters are given to the provider and
+        should be verified when the client comes back from the provider.
+        It is also used to store the client_redirect_url, which is used to
+        complete the SSO login flow.
+
+        Args:
+            state: The ``state`` parameter passed to the OIDC provider.
+            session_data: data to include in the session token.
+            duration_in_ms: An optional duration for the token in milliseconds.
+                Defaults to an hour.
+
+        Returns:
+            A signed macaroon token with the session information.
+        """
+        now = self._clock.time_msec()
+        expiry = now + duration_in_ms
+        macaroon = self._generate_base_macaroon("session")
+        macaroon.add_first_party_caveat(f"state = {state}")
+        macaroon.add_first_party_caveat(f"idp_id = {session_data.idp_id}")
+        macaroon.add_first_party_caveat(f"nonce = {session_data.nonce}")
+        macaroon.add_first_party_caveat(
+            f"client_redirect_url = {session_data.client_redirect_url}"
+        )
+        macaroon.add_first_party_caveat(
+            f"ui_auth_session_id = {session_data.ui_auth_session_id}"
+        )
+        macaroon.add_first_party_caveat(f"time < {expiry}")
+
+        return macaroon.serialize()
+
+    def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
+        """Verify a short-term-login macaroon
+
+        Checks that the given token is a valid, unexpired short-term-login token
+        minted by this server.
+
+        Args:
+            token: The login token to verify.
+
+        Returns:
+            A set of attributes carried by this token, including the
+            ``user_id`` and informations about the SSO IDP used during that
+            login.
+
+        Raises:
+            MacaroonVerificationFailedException if the verification failed
+        """
+        macaroon = pymacaroons.Macaroon.deserialize(token)
+
+        v = self._base_verifier("login")
+        v.satisfy_general(lambda c: c.startswith("user_id = "))
+        v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
+        v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = "))
+        satisfy_expiry(v, self._clock.time_msec)
+        v.verify(macaroon, self._secret_key)
+
+        user_id = get_value_from_macaroon(macaroon, "user_id")
+        auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
+
+        auth_provider_session_id: Optional[str] = None
+        try:
+            auth_provider_session_id = get_value_from_macaroon(
+                macaroon, "auth_provider_session_id"
+            )
+        except MacaroonVerificationFailedException:
+            pass
+
+        return LoginTokenAttributes(
+            user_id=user_id,
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
+        )
+
+    def verify_guest_token(self, token: str) -> str:
+        """Verify a guest access token macaroon
+
+        Checks that the given token is a valid, unexpired guest access token
+        minted by this server.
+
+        Args:
+            token: The access token to verify.
+
+        Returns:
+            The ``user_id`` that this token is valid for.
+
+        Raises:
+            MacaroonVerificationFailedException if the verification failed
+        """
+        macaroon = pymacaroons.Macaroon.deserialize(token)
+        user_id = get_value_from_macaroon(macaroon, "user_id")
+
+        # At some point, Synapse would generate macaroons without the "guest"
+        # caveat for regular users. Because of how macaroon verification works,
+        # to avoid validating those as guest tokens, we explicitely verify if
+        # the macaroon includes the "guest = true" caveat.
+        is_guest = any(
+            (caveat.caveat_id == "guest = true" for caveat in macaroon.caveats)
+        )
+
+        if not is_guest:
+            raise MacaroonVerificationFailedException("Macaroon is not a guest token")
+
+        v = self._base_verifier("access")
+        v.satisfy_exact("guest = true")
+        v.satisfy_general(lambda c: c.startswith("user_id = "))
+        v.satisfy_general(lambda c: c.startswith("nonce = "))
+        satisfy_expiry(v, self._clock.time_msec)
+        v.verify(macaroon, self._secret_key)
+
+        return user_id
+
+    def verify_delete_pusher_token(self, token: str, app_id: str, pushkey: str) -> str:
+        """Verify a token from an email unsubscribe link
+
+        Args:
+            token: The token to verify.
+            app_id: The app_id of the pusher to delete.
+            pushkey: The unique identifier of the pusher to delete.
+
+        Return:
+            The ``user_id`` for which this token is valid.
+
+        Raises:
+            MacaroonVerificationFailedException if the verification failed
+        """
+        macaroon = pymacaroons.Macaroon.deserialize(token)
+        user_id = get_value_from_macaroon(macaroon, "user_id")
+
+        v = self._base_verifier("delete_pusher")
+        v.satisfy_exact(f"app_id = {app_id}")
+        v.satisfy_exact(f"pushkey = {pushkey}")
+        v.satisfy_general(lambda c: c.startswith("user_id = "))
+        v.verify(macaroon, self._secret_key)
+
+        return user_id
+
+    def verify_oidc_session_token(self, session: bytes, state: str) -> OidcSessionData:
+        """Verifies and extract an OIDC session token.
+
+        This verifies that a given session token was issued by this homeserver
+        and extract the nonce and client_redirect_url caveats.
+
+        Args:
+            session: The session token to verify
+            state: The state the OIDC provider gave back
+
+        Returns:
+            The data extracted from the session cookie
+
+        Raises:
+            KeyError if an expected caveat is missing from the macaroon.
+        """
+        macaroon = pymacaroons.Macaroon.deserialize(session)
+
+        v = self._base_verifier("session")
+        v.satisfy_exact(f"state = {state}")
+        v.satisfy_general(lambda c: c.startswith("nonce = "))
+        v.satisfy_general(lambda c: c.startswith("idp_id = "))
+        v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
+        v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
+        satisfy_expiry(v, self._clock.time_msec)
+
+        v.verify(macaroon, self._secret_key)
+
+        # Extract the session data from the token.
+        nonce = get_value_from_macaroon(macaroon, "nonce")
+        idp_id = get_value_from_macaroon(macaroon, "idp_id")
+        client_redirect_url = get_value_from_macaroon(macaroon, "client_redirect_url")
+        ui_auth_session_id = get_value_from_macaroon(macaroon, "ui_auth_session_id")
+        return OidcSessionData(
+            nonce=nonce,
+            idp_id=idp_id,
+            client_redirect_url=client_redirect_url,
+            ui_auth_session_id=ui_auth_session_id,
+        )
+
+    def _generate_base_macaroon(self, type: MacaroonType) -> pymacaroons.Macaroon:
+        macaroon = pymacaroons.Macaroon(
+            location=self._location,
+            identifier="key",
+            key=self._secret_key,
+        )
+        macaroon.add_first_party_caveat("gen = 1")
+        macaroon.add_first_party_caveat(f"type = {type}")
+        return macaroon
+
+    def _base_verifier(self, type: MacaroonType) -> pymacaroons.Verifier:
+        v = pymacaroons.Verifier()
+        v.satisfy_exact("gen = 1")
+        v.satisfy_exact(f"type = {type}")
+        return v
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index dfe628c97e..9f64fed0d7 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -15,18 +15,35 @@
 import collections
 import contextlib
 import logging
+import threading
 import typing
-from typing import Any, DefaultDict, Iterator, List, Set
+from typing import (
+    Any,
+    Callable,
+    DefaultDict,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Set,
+    Tuple,
+)
+
+from prometheus_client.core import Counter
+from typing_extensions import ContextManager
 
 from twisted.internet import defer
 
 from synapse.api.errors import LimitExceededError
-from synapse.config.ratelimiting import FederationRateLimitConfig
+from synapse.config.ratelimiting import FederationRatelimitSettings
 from synapse.logging.context import (
     PreserveLoggingContext,
     make_deferred_yieldable,
     run_in_background,
 )
+from synapse.logging.opentracing import start_active_span
+from synapse.metrics import Histogram, LaterGauge
 from synapse.util import Clock
 
 if typing.TYPE_CHECKING:
@@ -35,15 +52,127 @@ if typing.TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+# Track how much the ratelimiter is affecting requests
+rate_limit_sleep_counter = Counter(
+    "synapse_rate_limit_sleep",
+    "Number of requests slept by the rate limiter",
+    ["rate_limiter_name"],
+)
+rate_limit_reject_counter = Counter(
+    "synapse_rate_limit_reject",
+    "Number of requests rejected by the rate limiter",
+    ["rate_limiter_name"],
+)
+queue_wait_timer = Histogram(
+    "synapse_rate_limit_queue_wait_time_seconds",
+    "Amount of time spent waiting for the rate limiter to let our request through.",
+    ["rate_limiter_name"],
+    buckets=(
+        0.005,
+        0.01,
+        0.025,
+        0.05,
+        0.1,
+        0.25,
+        0.5,
+        0.75,
+        1.0,
+        2.5,
+        5.0,
+        10.0,
+        20.0,
+        "+Inf",
+    ),
+)
+
+
+_rate_limiter_instances: Set["FederationRateLimiter"] = set()
+# Protects the _rate_limiter_instances set from concurrent access
+_rate_limiter_instances_lock = threading.Lock()
+
+
+def _get_counts_from_rate_limiter_instance(
+    count_func: Callable[["FederationRateLimiter"], int]
+) -> Mapping[Tuple[str, ...], int]:
+    """Returns a count of something (slept/rejected hosts) by (metrics_name)"""
+    # Cast to a list to prevent it changing while the Prometheus
+    # thread is collecting metrics
+    with _rate_limiter_instances_lock:
+        rate_limiter_instances = list(_rate_limiter_instances)
+
+    # Map from (metrics_name,) -> int, the number of something like slept hosts
+    # or rejected hosts. The key type is Tuple[str], but we leave the length
+    # unspecified for compatability with LaterGauge's annotations.
+    counts: Dict[Tuple[str, ...], int] = {}
+    for rate_limiter_instance in rate_limiter_instances:
+        # Only track metrics if they provided a `metrics_name` to
+        # differentiate this instance of the rate limiter.
+        if rate_limiter_instance.metrics_name:
+            key = (rate_limiter_instance.metrics_name,)
+            counts[key] = count_func(rate_limiter_instance)
+
+    return counts
+
+
+# We track the number of affected hosts per time-period so we can
+# differentiate one really noisy homeserver from a general
+# ratelimit tuning problem across the federation.
+LaterGauge(
+    "synapse_rate_limit_sleep_affected_hosts",
+    "Number of hosts that had requests put to sleep",
+    ["rate_limiter_name"],
+    lambda: _get_counts_from_rate_limiter_instance(
+        lambda rate_limiter_instance: sum(
+            ratelimiter.should_sleep()
+            for ratelimiter in rate_limiter_instance.ratelimiters.values()
+        )
+    ),
+)
+LaterGauge(
+    "synapse_rate_limit_reject_affected_hosts",
+    "Number of hosts that had requests rejected",
+    ["rate_limiter_name"],
+    lambda: _get_counts_from_rate_limiter_instance(
+        lambda rate_limiter_instance: sum(
+            ratelimiter.should_reject()
+            for ratelimiter in rate_limiter_instance.ratelimiters.values()
+        )
+    ),
+)
+
+
 class FederationRateLimiter:
-    def __init__(self, clock: Clock, config: FederationRateLimitConfig):
+    """Used to rate limit request per-host."""
+
+    def __init__(
+        self,
+        clock: Clock,
+        config: FederationRatelimitSettings,
+        metrics_name: Optional[str] = None,
+    ):
+        """
+        Args:
+            clock
+            config
+            metrics_name: The name of the rate limiter so we can differentiate it
+                from the rest in the metrics. If `None`, we don't track metrics
+                for this rate limiter.
+
+        """
+        self.metrics_name = metrics_name
+
         def new_limiter() -> "_PerHostRatelimiter":
-            return _PerHostRatelimiter(clock=clock, config=config)
+            return _PerHostRatelimiter(
+                clock=clock, config=config, metrics_name=metrics_name
+            )
 
         self.ratelimiters: DefaultDict[
             str, "_PerHostRatelimiter"
         ] = collections.defaultdict(new_limiter)
 
+        with _rate_limiter_instances_lock:
+            _rate_limiter_instances.add(self)
+
     def ratelimit(self, host: str) -> "_GeneratorContextManager[defer.Deferred[None]]":
         """Used to ratelimit an incoming request from a given host
 
@@ -59,17 +188,27 @@ class FederationRateLimiter:
         Returns:
             context manager which returns a deferred.
         """
-        return self.ratelimiters[host].ratelimit()
+        return self.ratelimiters[host].ratelimit(host)
 
 
 class _PerHostRatelimiter:
-    def __init__(self, clock: Clock, config: FederationRateLimitConfig):
+    def __init__(
+        self,
+        clock: Clock,
+        config: FederationRatelimitSettings,
+        metrics_name: Optional[str] = None,
+    ):
         """
         Args:
             clock
             config
+            metrics_name: The name of the rate limiter so we can differentiate it
+                from the rest in the metrics. If `None`, we don't track metrics
+                for this rate limiter.
+                from the rest in the metrics
         """
         self.clock = clock
+        self.metrics_name = metrics_name
 
         self.window_size = config.window_size
         self.sleep_limit = config.sleep_limit
@@ -94,19 +233,45 @@ class _PerHostRatelimiter:
         self.request_times: List[int] = []
 
     @contextlib.contextmanager
-    def ratelimit(self) -> "Iterator[defer.Deferred[None]]":
+    def ratelimit(self, host: str) -> "Iterator[defer.Deferred[None]]":
         # `contextlib.contextmanager` takes a generator and turns it into a
         # context manager. The generator should only yield once with a value
         # to be returned by manager.
         # Exceptions will be reraised at the yield.
 
+        self.host = host
+
         request_id = object()
-        ret = self._on_enter(request_id)
+        # Ideally we'd use `Deferred.fromCoroutine()` here, to save on redundant
+        # type-checking, but we'd need Twisted >= 21.2.
+        ret = defer.ensureDeferred(self._on_enter_with_tracing(request_id))
         try:
             yield ret
         finally:
             self._on_exit(request_id)
 
+    def should_reject(self) -> bool:
+        """
+        Whether to reject the request if we already have too many queued up
+        (either sleeping or in the ready queue).
+        """
+        queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
+        return queue_size > self.reject_limit
+
+    def should_sleep(self) -> bool:
+        """
+        Whether to sleep the request if we already have too many requests coming
+        through within the window.
+        """
+        return len(self.request_times) > self.sleep_limit
+
+    async def _on_enter_with_tracing(self, request_id: object) -> None:
+        maybe_metrics_cm: ContextManager = contextlib.nullcontext()
+        if self.metrics_name:
+            maybe_metrics_cm = queue_wait_timer.labels(self.metrics_name).time()
+        with start_active_span("ratelimit wait"), maybe_metrics_cm:
+            await self._on_enter(request_id)
+
     def _on_enter(self, request_id: object) -> "defer.Deferred[None]":
         time_now = self.clock.time_msec()
 
@@ -117,8 +282,10 @@ class _PerHostRatelimiter:
 
         # reject the request if we already have too many queued up (either
         # sleeping or in the ready queue).
-        queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
-        if queue_size > self.reject_limit:
+        if self.should_reject():
+            logger.debug("Ratelimiter(%s): rejecting request", self.host)
+            if self.metrics_name:
+                rate_limit_reject_counter.labels(self.metrics_name).inc()
             raise LimitExceededError(
                 retry_after_ms=int(self.window_size / self.sleep_limit)
             )
@@ -130,7 +297,8 @@ class _PerHostRatelimiter:
                 queue_defer: defer.Deferred[None] = defer.Deferred()
                 self.ready_request_queue[request_id] = queue_defer
                 logger.info(
-                    "Ratelimiter: queueing request (queue now %i items)",
+                    "Ratelimiter(%s): queueing request (queue now %i items)",
+                    self.host,
                     len(self.ready_request_queue),
                 )
 
@@ -139,19 +307,29 @@ class _PerHostRatelimiter:
                 return defer.succeed(None)
 
         logger.debug(
-            "Ratelimit [%s]: len(self.request_times)=%d",
+            "Ratelimit(%s) [%s]: len(self.request_times)=%d",
+            self.host,
             id(request_id),
             len(self.request_times),
         )
 
-        if len(self.request_times) > self.sleep_limit:
-            logger.debug("Ratelimiter: sleeping request for %f sec", self.sleep_sec)
+        if self.should_sleep():
+            logger.debug(
+                "Ratelimiter(%s) [%s]: sleeping request for %f sec",
+                self.host,
+                id(request_id),
+                self.sleep_sec,
+            )
+            if self.metrics_name:
+                rate_limit_sleep_counter.labels(self.metrics_name).inc()
             ret_defer = run_in_background(self.clock.sleep, self.sleep_sec)
 
             self.sleeping_requests.add(request_id)
 
             def on_wait_finished(_: Any) -> "defer.Deferred[None]":
-                logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id))
+                logger.debug(
+                    "Ratelimit(%s) [%s]: Finished sleeping", self.host, id(request_id)
+                )
                 self.sleeping_requests.discard(request_id)
                 queue_defer = queue_request()
                 return queue_defer
@@ -161,7 +339,9 @@ class _PerHostRatelimiter:
             ret_defer = queue_request()
 
         def on_start(r: object) -> object:
-            logger.debug("Ratelimit [%s]: Processing req", id(request_id))
+            logger.debug(
+                "Ratelimit(%s) [%s]: Processing req", self.host, id(request_id)
+            )
             self.current_processing.add(request_id)
             return r
 
@@ -183,7 +363,7 @@ class _PerHostRatelimiter:
         return make_deferred_yieldable(ret_defer)
 
     def _on_exit(self, request_id: object) -> None:
-        logger.debug("Ratelimit [%s]: Processed req", id(request_id))
+        logger.debug("Ratelimit(%s) [%s]: Processed req", self.host, id(request_id))
         self.current_processing.discard(request_id)
         try:
             # start processing the next item on the queue.
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 8aaa8c709f..c810a05907 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -13,16 +13,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from enum import Enum, auto
 from typing import Collection, Dict, FrozenSet, List, Optional, Tuple
 
+import attr
 from typing_extensions import Final
 
 from synapse.api.constants import EventTypes, HistoryVisibility, Membership
 from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
 from synapse.events.utils import prune_event
+from synapse.logging.opentracing import trace
 from synapse.storage.controllers import StorageControllers
+from synapse.storage.databases.main import DataStore
 from synapse.storage.state import StateFilter
 from synapse.types import RetentionPolicy, StateMap, get_domain_from_id
+from synapse.util import Clock
 
 logger = logging.getLogger(__name__)
 
@@ -46,6 +52,7 @@ MEMBERSHIP_PRIORITY = (
 _HISTORY_VIS_KEY: Final[Tuple[str, str]] = (EventTypes.RoomHistoryVisibility, "")
 
 
+@trace
 async def filter_events_for_client(
     storage: StorageControllers,
     user_id: str,
@@ -66,8 +73,8 @@ async def filter_events_for_client(
           * the user is not currently a member of the room, and:
           * the user has not been a member of the room since the given
             events
-        always_include_ids: set of event ids to specifically
-            include (unless sender is ignored)
+        always_include_ids: set of event ids to specifically include, if present
+            in events (unless sender is ignored)
         filter_send_to_client: Whether we're checking an event that's going to be
             sent to a client. This might not always be the case since this function can
             also be called to check whether a user can see the state at a given point.
@@ -102,165 +109,410 @@ async def filter_events_for_client(
             ] = await storage.main.get_retention_policy_for_room(room_id)
 
     def allowed(event: EventBase) -> Optional[EventBase]:
-        """
-        Args:
-            event: event to check
-
-        Returns:
-           None if the user cannot see this event at all
-
-           a redacted copy of the event if they can only see a redacted
-           version
-
-           the original event if they can see it as normal.
-        """
-        # Only run some checks if these events aren't about to be sent to clients. This is
-        # because, if this is not the case, we're probably only checking if the users can
-        # see events in the room at that point in the DAG, and that shouldn't be decided
-        # on those checks.
-        if filter_send_to_client:
-            if event.type == EventTypes.Dummy:
-                return None
-
-            if not event.is_state() and event.sender in ignore_list:
-                return None
-
-            # Until MSC2261 has landed we can't redact malicious alias events, so for
-            # now we temporarily filter out m.room.aliases entirely to mitigate
-            # abuse, while we spec a better solution to advertising aliases
-            # on rooms.
-            if event.type == EventTypes.Aliases:
-                return None
-
-            # Don't try to apply the room's retention policy if the event is a state
-            # event, as MSC1763 states that retention is only considered for non-state
-            # events.
-            if not event.is_state():
-                retention_policy = retention_policies[event.room_id]
-                max_lifetime = retention_policy.max_lifetime
-
-                if max_lifetime is not None:
-                    oldest_allowed_ts = storage.main.clock.time_msec() - max_lifetime
-
-                    if event.origin_server_ts < oldest_allowed_ts:
-                        return None
-
-        if event.event_id in always_include_ids:
-            return event
+        return _check_client_allowed_to_see_event(
+            user_id=user_id,
+            event=event,
+            clock=storage.main.clock,
+            filter_send_to_client=filter_send_to_client,
+            sender_ignored=event.sender in ignore_list,
+            always_include_ids=always_include_ids,
+            retention_policy=retention_policies[room_id],
+            state=event_id_to_state.get(event.event_id),
+            is_peeking=is_peeking,
+            sender_erased=erased_senders.get(event.sender, False),
+        )
 
-        # we need to handle outliers separately, since we don't have the room state.
-        if event.internal_metadata.outlier:
-            # Normally these can't be seen by clients, but we make an exception for
-            # for out-of-band membership events (eg, incoming invites, or rejections of
-            # said invite) for the user themselves.
-            if event.type == EventTypes.Member and event.state_key == user_id:
-                logger.debug("Returning out-of-band-membership event %s", event)
-                return event
+    # Check each event: gives an iterable of None or (a potentially modified)
+    # EventBase.
+    filtered_events = map(allowed, events)
+
+    # Turn it into a list and remove None entries before returning.
+    return [ev for ev in filtered_events if ev]
 
-            return None
 
-        state = event_id_to_state[event.event_id]
+async def filter_event_for_clients_with_state(
+    store: DataStore,
+    user_ids: Collection[str],
+    event: EventBase,
+    context: EventContext,
+    is_peeking: bool = False,
+    filter_send_to_client: bool = True,
+) -> Collection[str]:
+    """
+    Checks to see if an event is visible to the users in the list at the time of
+    the event.
 
-        # get the room_visibility at the time of the event.
-        visibility = get_effective_room_visibility_from_state(state)
+    Note: This does *not* check if the sender of the event was erased.
 
-        # Always allow history visibility events on boundaries. This is done
-        # by setting the effective visibility to the least restrictive
-        # of the old vs new.
-        if event.type == EventTypes.RoomHistoryVisibility:
-            prev_content = event.unsigned.get("prev_content", {})
-            prev_visibility = prev_content.get("history_visibility", None)
+    Args:
+        store: databases
+        user_ids: user_ids to be checked
+        event: the event to be checked
+        context: EventContext for the event to be checked
+        is_peeking: Whether the users are peeking into the room, ie not
+            currently joined
+        filter_send_to_client: Whether we're checking an event that's going to be
+            sent to a client. This might not always be the case since this function can
+            also be called to check whether a user can see the state at a given point.
 
-            if prev_visibility not in VISIBILITY_PRIORITY:
-                prev_visibility = HistoryVisibility.SHARED
+    Returns:
+        Collection of user IDs for whom the event is visible
+    """
+    # None of the users should see the event if it is soft_failed
+    if event.internal_metadata.is_soft_failed():
+        return []
 
-            new_priority = VISIBILITY_PRIORITY.index(visibility)
-            old_priority = VISIBILITY_PRIORITY.index(prev_visibility)
-            if old_priority < new_priority:
-                visibility = prev_visibility
+    # Make a set for all user IDs that haven't been filtered out by a check.
+    allowed_user_ids = set(user_ids)
 
-        # likewise, if the event is the user's own membership event, use
-        # the 'most joined' membership
-        membership = None
-        if event.type == EventTypes.Member and event.state_key == user_id:
-            membership = event.content.get("membership", None)
-            if membership not in MEMBERSHIP_PRIORITY:
-                membership = "leave"
-
-            prev_content = event.unsigned.get("prev_content", {})
-            prev_membership = prev_content.get("membership", None)
-            if prev_membership not in MEMBERSHIP_PRIORITY:
-                prev_membership = "leave"
-
-            # Always allow the user to see their own leave events, otherwise
-            # they won't see the room disappear if they reject the invite
-            #
-            # (Note this doesn't work for out-of-band invite rejections, which don't
-            # have prev_state populated. They are handled above in the outlier code.)
-            if membership == "leave" and (
-                prev_membership == "join" or prev_membership == "invite"
+    # Only run some checks if these events aren't about to be sent to clients. This is
+    # because, if this is not the case, we're probably only checking if the users can
+    # see events in the room at that point in the DAG, and that shouldn't be decided
+    # on those checks.
+    if filter_send_to_client:
+        ignored_by = await store.ignored_by(event.sender)
+        retention_policy = await store.get_retention_policy_for_room(event.room_id)
+
+        for user_id in user_ids:
+            if (
+                _check_filter_send_to_client(
+                    event,
+                    store.clock,
+                    retention_policy,
+                    sender_ignored=user_id in ignored_by,
+                )
+                == _CheckFilter.DENIED
             ):
-                return event
-
-            new_priority = MEMBERSHIP_PRIORITY.index(membership)
-            old_priority = MEMBERSHIP_PRIORITY.index(prev_membership)
-            if old_priority < new_priority:
-                membership = prev_membership
-
-        # otherwise, get the user's membership at the time of the event.
-        if membership is None:
-            membership_event = state.get((EventTypes.Member, user_id), None)
-            if membership_event:
-                membership = membership_event.membership
-
-        # if the user was a member of the room at the time of the event,
-        # they can see it.
-        if membership == Membership.JOIN:
-            return event
+                allowed_user_ids.discard(user_id)
+
+    if event.internal_metadata.outlier:
+        # Normally these can't be seen by clients, but we make an exception for
+        # for out-of-band membership events (eg, incoming invites, or rejections of
+        # said invite) for the user themselves.
+        if event.type == EventTypes.Member and event.state_key in allowed_user_ids:
+            logger.debug("Returning out-of-band-membership event %s", event)
+            return {event.state_key}
+
+        return set()
+
+    # First we get just the history visibility in case its shared/world-readable
+    # room.
+    visibility_state_map = await _get_state_map(
+        store, event, context, StateFilter.from_types([_HISTORY_VIS_KEY])
+    )
 
-        # otherwise, it depends on the room visibility.
+    visibility = get_effective_room_visibility_from_state(visibility_state_map)
+    if (
+        _check_history_visibility(event, visibility, is_peeking=is_peeking)
+        == _CheckVisibility.ALLOWED
+    ):
+        return allowed_user_ids
 
-        if visibility == HistoryVisibility.JOINED:
-            # we weren't a member at the time of the event, so we can't
-            # see this event.
-            return None
+    # The history visibility isn't lax, so we now need to fetch the membership
+    # events of all the users.
+
+    filter_list = []
+    for user_id in allowed_user_ids:
+        filter_list.append((EventTypes.Member, user_id))
+    filter_list.append((EventTypes.RoomHistoryVisibility, ""))
+
+    state_filter = StateFilter.from_types(filter_list)
+    state_map = await _get_state_map(store, event, context, state_filter)
+
+    # Now we check whether the membership allows each user to see the event.
+    return {
+        user_id
+        for user_id in allowed_user_ids
+        if _check_membership(user_id, event, visibility, state_map, is_peeking).allowed
+    }
+
+
+async def _get_state_map(
+    store: DataStore, event: EventBase, context: EventContext, state_filter: StateFilter
+) -> StateMap[EventBase]:
+    """Helper function for getting a `StateMap[EventBase]` from an `EventContext`"""
+    state_map = await context.get_prev_state_ids(state_filter)
+
+    # Use events rather than event ids as content from the events are needed in
+    # _check_visibility
+    event_map = await store.get_events(state_map.values(), get_prev_content=False)
+
+    updated_state_map = {}
+    for state_key, event_id in state_map.items():
+        state_event = event_map.get(event_id)
+        if state_event:
+            updated_state_map[state_key] = state_event
+
+    if event.is_state():
+        current_state_key = (event.type, event.state_key)
+        # Add current event to updated_state_map, we need to do this here as it
+        # may not have been persisted to the db yet
+        updated_state_map[current_state_key] = event
 
-        elif visibility == HistoryVisibility.INVITED:
-            # user can also see the event if they were *invited* at the time
-            # of the event.
-            return event if membership == Membership.INVITE else None
-
-        elif visibility == HistoryVisibility.SHARED and is_peeking:
-            # if the visibility is shared, users cannot see the event unless
-            # they have *subsequently* joined the room (or were members at the
-            # time, of course)
-            #
-            # XXX: if the user has subsequently joined and then left again,
-            # ideally we would share history up to the point they left. But
-            # we don't know when they left. We just treat it as though they
-            # never joined, and restrict access.
+    return updated_state_map
+
+
+def _check_client_allowed_to_see_event(
+    user_id: str,
+    event: EventBase,
+    clock: Clock,
+    filter_send_to_client: bool,
+    is_peeking: bool,
+    always_include_ids: FrozenSet[str],
+    sender_ignored: bool,
+    retention_policy: RetentionPolicy,
+    state: Optional[StateMap[EventBase]],
+    sender_erased: bool,
+) -> Optional[EventBase]:
+    """Check with the given user is allowed to see the given event
+
+    See `filter_events_for_client` for details about args
+
+    Args:
+        user_id
+        event
+        clock
+        filter_send_to_client
+        is_peeking
+        always_include_ids
+        sender_ignored: Whether the user is ignoring the event sender
+        retention_policy: The retention policy of the room
+        state: The state at the event, unless its an outlier
+        sender_erased: Whether the event sender has been marked as "erased"
+
+    Returns:
+        None if the user cannot see this event at all
+
+        a redacted copy of the event if they can only see a redacted
+        version
+
+        the original event if they can see it as normal.
+    """
+    # Only run some checks if these events aren't about to be sent to clients. This is
+    # because, if this is not the case, we're probably only checking if the users can
+    # see events in the room at that point in the DAG, and that shouldn't be decided
+    # on those checks.
+    if filter_send_to_client:
+        if (
+            _check_filter_send_to_client(event, clock, retention_policy, sender_ignored)
+            == _CheckFilter.DENIED
+        ):
             return None
 
-        # the visibility is either shared or world_readable, and the user was
-        # not a member at the time. We allow it, provided the original sender
-        # has not requested their data to be erased, in which case, we return
-        # a redacted version.
-        if erased_senders[event.sender]:
-            return prune_event(event)
+    if event.event_id in always_include_ids:
+        return event
+
+    # we need to handle outliers separately, since we don't have the room state.
+    if event.internal_metadata.outlier:
+        # Normally these can't be seen by clients, but we make an exception for
+        # for out-of-band membership events (eg, incoming invites, or rejections of
+        # said invite) for the user themselves.
+        if event.type == EventTypes.Member and event.state_key == user_id:
+            logger.debug("Returning out-of-band-membership event %s", event)
+            return event
+
+        return None
+
+    if state is None:
+        raise Exception("Missing state for non-outlier event")
 
+    # get the room_visibility at the time of the event.
+    visibility = get_effective_room_visibility_from_state(state)
+
+    # Check if the room has lax history visibility, allowing us to skip
+    # membership checks.
+    #
+    # We can only do this check if the sender has *not* been erased, as if they
+    # have we need to check the user's membership.
+    if (
+        not sender_erased
+        and _check_history_visibility(event, visibility, is_peeking)
+        == _CheckVisibility.ALLOWED
+    ):
         return event
 
-    # Check each event: gives an iterable of None or (a potentially modified)
-    # EventBase.
-    filtered_events = map(allowed, events)
+    membership_result = _check_membership(user_id, event, visibility, state, is_peeking)
+    if not membership_result.allowed:
+        return None
 
-    # Turn it into a list and remove None entries before returning.
-    return [ev for ev in filtered_events if ev]
+    # If the sender has been erased and the user was not joined at the time, we
+    # must only return the redacted form.
+    if sender_erased and not membership_result.joined:
+        event = prune_event(event)
+
+    return event
+
+
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class _CheckMembershipReturn:
+    "Return value of _check_membership"
+    allowed: bool
+    joined: bool
+
+
+def _check_membership(
+    user_id: str,
+    event: EventBase,
+    visibility: str,
+    state: StateMap[EventBase],
+    is_peeking: bool,
+) -> _CheckMembershipReturn:
+    """Check whether the user can see the event due to their membership
+
+    Returns:
+        True if they can, False if they can't, plus the membership of the user
+        at the event.
+    """
+    # If the event is the user's own membership event, use the 'most joined'
+    # membership
+    membership = None
+    if event.type == EventTypes.Member and event.state_key == user_id:
+        membership = event.content.get("membership", None)
+        if membership not in MEMBERSHIP_PRIORITY:
+            membership = "leave"
+
+        prev_content = event.unsigned.get("prev_content", {})
+        prev_membership = prev_content.get("membership", None)
+        if prev_membership not in MEMBERSHIP_PRIORITY:
+            prev_membership = "leave"
+
+        # Always allow the user to see their own leave events, otherwise
+        # they won't see the room disappear if they reject the invite
+        #
+        # (Note this doesn't work for out-of-band invite rejections, which don't
+        # have prev_state populated. They are handled above in the outlier code.)
+        if membership == "leave" and (
+            prev_membership == "join" or prev_membership == "invite"
+        ):
+            return _CheckMembershipReturn(True, membership == Membership.JOIN)
+
+        new_priority = MEMBERSHIP_PRIORITY.index(membership)
+        old_priority = MEMBERSHIP_PRIORITY.index(prev_membership)
+        if old_priority < new_priority:
+            membership = prev_membership
+
+    # otherwise, get the user's membership at the time of the event.
+    if membership is None:
+        membership_event = state.get((EventTypes.Member, user_id), None)
+        if membership_event:
+            membership = membership_event.membership
+
+    # if the user was a member of the room at the time of the event,
+    # they can see it.
+    if membership == Membership.JOIN:
+        return _CheckMembershipReturn(True, True)
+
+    # otherwise, it depends on the room visibility.
+
+    if visibility == HistoryVisibility.JOINED:
+        # we weren't a member at the time of the event, so we can't
+        # see this event.
+        return _CheckMembershipReturn(False, False)
+
+    elif visibility == HistoryVisibility.INVITED:
+        # user can also see the event if they were *invited* at the time
+        # of the event.
+        return _CheckMembershipReturn(membership == Membership.INVITE, False)
+
+    elif visibility == HistoryVisibility.SHARED and is_peeking:
+        # if the visibility is shared, users cannot see the event unless
+        # they have *subsequently* joined the room (or were members at the
+        # time, of course)
+        #
+        # XXX: if the user has subsequently joined and then left again,
+        # ideally we would share history up to the point they left. But
+        # we don't know when they left. We just treat it as though they
+        # never joined, and restrict access.
+        return _CheckMembershipReturn(False, False)
+
+    # The visibility is either shared or world_readable, and the user was
+    # not a member at the time. We allow it.
+    return _CheckMembershipReturn(True, False)
+
+
+class _CheckFilter(Enum):
+    MAYBE_ALLOWED = auto()
+    DENIED = auto()
+
+
+def _check_filter_send_to_client(
+    event: EventBase,
+    clock: Clock,
+    retention_policy: RetentionPolicy,
+    sender_ignored: bool,
+) -> _CheckFilter:
+    """Apply checks for sending events to client
+
+    Returns:
+        True if might be allowed to be sent to clients, False if definitely not.
+    """
+
+    if event.type == EventTypes.Dummy:
+        return _CheckFilter.DENIED
+
+    if not event.is_state() and sender_ignored:
+        return _CheckFilter.DENIED
+
+    # Until MSC2261 has landed we can't redact malicious alias events, so for
+    # now we temporarily filter out m.room.aliases entirely to mitigate
+    # abuse, while we spec a better solution to advertising aliases
+    # on rooms.
+    if event.type == EventTypes.Aliases:
+        return _CheckFilter.DENIED
+
+    # Don't try to apply the room's retention policy if the event is a state
+    # event, as MSC1763 states that retention is only considered for non-state
+    # events.
+    if not event.is_state():
+        max_lifetime = retention_policy.max_lifetime
+
+        if max_lifetime is not None:
+            oldest_allowed_ts = clock.time_msec() - max_lifetime
+
+            if event.origin_server_ts < oldest_allowed_ts:
+                return _CheckFilter.DENIED
+
+    return _CheckFilter.MAYBE_ALLOWED
+
+
+class _CheckVisibility(Enum):
+    ALLOWED = auto()
+    MAYBE_DENIED = auto()
+
+
+def _check_history_visibility(
+    event: EventBase, visibility: str, is_peeking: bool
+) -> _CheckVisibility:
+    """Check if event is allowed to be seen due to lax history visibility.
+
+    Returns:
+        True if user can definitely see the event, False if maybe not.
+    """
+    # Always allow history visibility events on boundaries. This is done
+    # by setting the effective visibility to the least restrictive
+    # of the old vs new.
+    if event.type == EventTypes.RoomHistoryVisibility:
+        prev_content = event.unsigned.get("prev_content", {})
+        prev_visibility = prev_content.get("history_visibility", None)
+
+        if prev_visibility not in VISIBILITY_PRIORITY:
+            prev_visibility = HistoryVisibility.SHARED
+
+        new_priority = VISIBILITY_PRIORITY.index(visibility)
+        old_priority = VISIBILITY_PRIORITY.index(prev_visibility)
+        if old_priority < new_priority:
+            visibility = prev_visibility
+
+    if visibility == HistoryVisibility.SHARED and not is_peeking:
+        return _CheckVisibility.ALLOWED
+    elif visibility == HistoryVisibility.WORLD_READABLE:
+        return _CheckVisibility.ALLOWED
+
+    return _CheckVisibility.MAYBE_DENIED
 
 
 def get_effective_room_visibility_from_state(state: StateMap[EventBase]) -> str:
     """Get the actual history vis, from a state map including the history_visibility event
-
     Handles missing and invalid history visibility events.
     """
     visibility_event = state.get(_HISTORY_VIS_KEY, None)