summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorEric Eastwood <erice@element.io>2022-09-14 18:14:17 -0500
committerEric Eastwood <erice@element.io>2022-09-14 18:14:17 -0500
commitdab846568de09742c038a9a5a5ed00539ff5aa38 (patch)
treeef4220faa96569d1d8d5f386c92bb4df2655eeed /synapse
parentAdd changelog (diff)
parentKeep track when we try and fail to process a pulled event (#13589) (diff)
downloadsynapse-madlittlemods/event_id_always_failed_to_fetch.tar.xz
Merge branch 'develop' into madlittlemods/event_id_always_failed_to_fetch github/madlittlemods/event_id_always_failed_to_fetch madlittlemods/event_id_always_failed_to_fetch
Conflicts:
	synapse/handlers/federation_event.py
Diffstat (limited to 'synapse')
-rw-r--r--synapse/__init__.py5
-rwxr-xr-xsynapse/_scripts/generate_config.py2
-rw-r--r--synapse/_scripts/register_new_matrix_user.py102
-rwxr-xr-xsynapse/_scripts/synapse_port_db.py18
-rw-r--r--synapse/api/auth.py242
-rw-r--r--synapse/api/constants.py10
-rw-r--r--synapse/api/errors.py58
-rw-r--r--synapse/api/filtering.py8
-rw-r--r--synapse/api/ratelimiting.py100
-rw-r--r--synapse/api/room_versions.py98
-rw-r--r--synapse/app/_base.py44
-rw-r--r--synapse/app/admin_cmd.py28
-rw-r--r--synapse/app/generic_worker.py49
-rw-r--r--synapse/app/homeserver.py25
-rw-r--r--synapse/app/phone_stats_home.py19
-rw-r--r--synapse/appservice/api.py14
-rw-r--r--synapse/config/_base.py73
-rw-r--r--synapse/config/account_validity.py2
-rw-r--r--synapse/config/emailconfig.py54
-rw-r--r--synapse/config/experimental.py13
-rw-r--r--synapse/config/key.py13
-rw-r--r--synapse/config/metrics.py29
-rw-r--r--synapse/config/ratelimiting.py47
-rw-r--r--synapse/config/registration.py43
-rw-r--r--synapse/config/repository.py35
-rw-r--r--synapse/config/server.py16
-rw-r--r--synapse/config/sso.py2
-rw-r--r--synapse/config/workers.py8
-rw-r--r--synapse/crypto/event_signing.py2
-rw-r--r--synapse/event_auth.py92
-rw-r--r--synapse/events/__init__.py12
-rw-r--r--synapse/events/builder.py14
-rw-r--r--synapse/events/snapshot.py7
-rw-r--r--synapse/events/spamcheck.py2
-rw-r--r--synapse/events/utils.py2
-rw-r--r--synapse/events/validator.py2
-rw-r--r--synapse/federation/federation_base.py24
-rw-r--r--synapse/federation/federation_client.py183
-rw-r--r--synapse/federation/federation_server.py57
-rw-r--r--synapse/federation/sender/__init__.py27
-rw-r--r--synapse/federation/transport/client.py2
-rw-r--r--synapse/federation/transport/server/_base.py7
-rw-r--r--synapse/federation/transport/server/federation.py3
-rw-r--r--synapse/handlers/admin.py1
-rw-r--r--synapse/handlers/appservice.py11
-rw-r--r--synapse/handlers/auth.py19
-rw-r--r--synapse/handlers/device.py39
-rw-r--r--synapse/handlers/directory.py34
-rw-r--r--synapse/handlers/e2e_keys.py98
-rw-r--r--synapse/handlers/e2e_room_keys.py4
-rw-r--r--synapse/handlers/event_auth.py9
-rw-r--r--synapse/handlers/events.py9
-rw-r--r--synapse/handlers/federation.py178
-rw-r--r--synapse/handlers/federation_event.py719
-rw-r--r--synapse/handlers/identity.py232
-rw-r--r--synapse/handlers/initial_sync.py17
-rw-r--r--synapse/handlers/message.py95
-rw-r--r--synapse/handlers/pagination.py45
-rw-r--r--synapse/handlers/presence.py115
-rw-r--r--synapse/handlers/receipts.py7
-rw-r--r--synapse/handlers/register.py15
-rw-r--r--synapse/handlers/relations.py7
-rw-r--r--synapse/handlers/room.py149
-rw-r--r--synapse/handlers/room_list.py2
-rw-r--r--synapse/handlers/room_member.py209
-rw-r--r--synapse/handlers/room_summary.py6
-rw-r--r--synapse/handlers/send_email.py36
-rw-r--r--synapse/handlers/sync.py410
-rw-r--r--synapse/handlers/typing.py30
-rw-r--r--synapse/handlers/ui_auth/checkers.py21
-rw-r--r--synapse/http/matrixfederationclient.py9
-rw-r--r--synapse/http/server.py86
-rw-r--r--synapse/http/servlet.py40
-rw-r--r--synapse/http/site.py2
-rw-r--r--synapse/logging/opentracing.py269
-rw-r--r--synapse/metrics/__init__.py4
-rw-r--r--synapse/metrics/_legacy_exposition.py (renamed from synapse/metrics/_exposition.py)46
-rw-r--r--synapse/metrics/background_process_metrics.py2
-rw-r--r--synapse/metrics/common_usage_metrics.py79
-rw-r--r--synapse/module_api/__init__.py83
-rw-r--r--synapse/notifier.py18
-rw-r--r--synapse/push/baserules.py529
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py78
-rw-r--r--synapse/push/clientformat.py68
-rw-r--r--synapse/push/push_rule_evaluator.py27
-rw-r--r--synapse/push/push_tools.py13
-rw-r--r--synapse/push/pusherpool.py27
-rw-r--r--synapse/replication/http/_base.py11
-rw-r--r--synapse/replication/slave/storage/_base.py58
-rw-r--r--synapse/replication/slave/storage/account_data.py22
-rw-r--r--synapse/replication/slave/storage/appservice.py25
-rw-r--r--synapse/replication/slave/storage/deviceinbox.py20
-rw-r--r--synapse/replication/slave/storage/devices.py3
-rw-r--r--synapse/replication/slave/storage/directory.py21
-rw-r--r--synapse/replication/slave/storage/events.py3
-rw-r--r--synapse/replication/slave/storage/filtering.py5
-rw-r--r--synapse/replication/slave/storage/profile.py20
-rw-r--r--synapse/replication/slave/storage/push_rule.py1
-rw-r--r--synapse/replication/slave/storage/pushers.py3
-rw-r--r--synapse/replication/slave/storage/receipts.py22
-rw-r--r--synapse/replication/slave/storage/registration.py21
-rw-r--r--synapse/replication/tcp/client.py17
-rw-r--r--synapse/replication/tcp/handler.py58
-rw-r--r--synapse/replication/tcp/streams/events.py1
-rw-r--r--synapse/res/templates/account_previously_renewed.html13
-rw-r--r--synapse/res/templates/account_renewed.html13
-rw-r--r--synapse/res/templates/add_threepid.html11
-rw-r--r--synapse/res/templates/add_threepid_failure.html15
-rw-r--r--synapse/res/templates/add_threepid_success.html14
-rw-r--r--synapse/res/templates/auth_success.html4
-rw-r--r--synapse/res/templates/invalid_token.html13
-rw-r--r--synapse/res/templates/notice_expiry.html2
-rw-r--r--synapse/res/templates/notif_mail.html2
-rw-r--r--synapse/res/templates/password_reset.html7
-rw-r--r--synapse/res/templates/password_reset_confirmation.html8
-rw-r--r--synapse/res/templates/password_reset_failure.html8
-rw-r--r--synapse/res/templates/password_reset_success.html7
-rw-r--r--synapse/res/templates/recaptcha.html4
-rw-r--r--synapse/res/templates/registration.html7
-rw-r--r--synapse/res/templates/registration_failure.html7
-rw-r--r--synapse/res/templates/registration_success.html8
-rw-r--r--synapse/res/templates/registration_token.html6
-rw-r--r--synapse/res/templates/sso_account_deactivated.html4
-rw-r--r--synapse/res/templates/sso_auth_account_details.html5
-rw-r--r--synapse/res/templates/sso_auth_bad_user.html3
-rw-r--r--synapse/res/templates/sso_auth_confirm.html3
-rw-r--r--synapse/res/templates/sso_auth_success.html3
-rw-r--r--synapse/res/templates/sso_error.html3
-rw-r--r--synapse/res/templates/sso_login_idp_picker.html2
-rw-r--r--synapse/res/templates/sso_new_user_consent.html3
-rw-r--r--synapse/res/templates/sso_redirect_confirm.html3
-rw-r--r--synapse/res/templates/terms.html4
-rw-r--r--synapse/rest/admin/__init__.py4
-rw-r--r--synapse/rest/admin/_base.py10
-rw-r--r--synapse/rest/admin/media.py6
-rw-r--r--synapse/rest/admin/rooms.py117
-rw-r--r--synapse/rest/admin/users.py16
-rw-r--r--synapse/rest/client/account.py283
-rw-r--r--synapse/rest/client/devices.py27
-rw-r--r--synapse/rest/client/keys.py13
-rw-r--r--synapse/rest/client/login.py10
-rw-r--r--synapse/rest/client/models.py85
-rw-r--r--synapse/rest/client/notifications.py6
-rw-r--r--synapse/rest/client/profile.py4
-rw-r--r--synapse/rest/client/read_marker.py58
-rw-r--r--synapse/rest/client/receipts.py20
-rw-r--r--synapse/rest/client/register.py66
-rw-r--r--synapse/rest/client/room.py123
-rw-r--r--synapse/rest/client/room_keys.py13
-rw-r--r--synapse/rest/client/sendtodevice.py3
-rw-r--r--synapse/rest/client/sync.py12
-rw-r--r--synapse/rest/client/versions.py6
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py41
-rw-r--r--synapse/rest/media/v1/_base.py40
-rw-r--r--synapse/rest/media/v1/media_repository.py1
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py66
-rw-r--r--synapse/rest/media/v1/thumbnail_resource.py40
-rw-r--r--synapse/rest/models.py23
-rw-r--r--synapse/rest/synapse/client/password_reset.py8
-rw-r--r--synapse/server.py22
-rw-r--r--synapse/server_notices/server_notices_manager.py12
-rw-r--r--synapse/state/__init__.py281
-rw-r--r--synapse/state/v2.py81
-rw-r--r--synapse/static/client/login/index.html3
-rw-r--r--synapse/static/client/register/index.html3
-rw-r--r--synapse/storage/_base.py13
-rw-r--r--synapse/storage/background_updates.py6
-rw-r--r--synapse/storage/controllers/__init__.py4
-rw-r--r--synapse/storage/controllers/persist_events.py65
-rw-r--r--synapse/storage/controllers/state.py66
-rw-r--r--synapse/storage/database.py105
-rw-r--r--synapse/storage/databases/main/__init__.py29
-rw-r--r--synapse/storage/databases/main/account_data.py3
-rw-r--r--synapse/storage/databases/main/appservice.py58
-rw-r--r--synapse/storage/databases/main/cache.py33
-rw-r--r--synapse/storage/databases/main/censor_events.py2
-rw-r--r--synapse/storage/databases/main/deviceinbox.py2
-rw-r--r--synapse/storage/databases/main/devices.py10
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py58
-rw-r--r--synapse/storage/databases/main/event_federation.py58
-rw-r--r--synapse/storage/databases/main/event_push_actions.py623
-rw-r--r--synapse/storage/databases/main/events.py73
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py87
-rw-r--r--synapse/storage/databases/main/events_worker.py316
-rw-r--r--synapse/storage/databases/main/lock.py121
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py1
-rw-r--r--synapse/storage/databases/main/purge_events.py35
-rw-r--r--synapse/storage/databases/main/push_rule.py128
-rw-r--r--synapse/storage/databases/main/receipts.py108
-rw-r--r--synapse/storage/databases/main/registration.py8
-rw-r--r--synapse/storage/databases/main/relations.py6
-rw-r--r--synapse/storage/databases/main/room.py31
-rw-r--r--synapse/storage/databases/main/roommember.py633
-rw-r--r--synapse/storage/databases/main/state.py17
-rw-r--r--synapse/storage/databases/main/stats.py86
-rw-r--r--synapse/storage/databases/main/stream.py36
-rw-r--r--synapse/storage/databases/main/transactions.py30
-rw-r--r--synapse/storage/databases/state/bg_updates.py3
-rw-r--r--synapse/storage/databases/state/store.py168
-rw-r--r--synapse/storage/engines/_base.py8
-rw-r--r--synapse/storage/engines/postgres.py7
-rw-r--r--synapse/storage/engines/sqlite.py13
-rw-r--r--synapse/storage/schema/__init__.py17
-rw-r--r--synapse/storage/schema/main/delta/72/03bg_populate_events_columns.py47
-rw-r--r--synapse/storage/schema/main/delta/72/03drop_event_reference_hashes.sql17
-rw-r--r--synapse/storage/schema/main/delta/72/03remove_groups.sql31
-rw-r--r--synapse/storage/schema/main/delta/72/04drop_column_application_services_state_last_txn.sql.postgres17
-rw-r--r--synapse/storage/schema/main/delta/72/04drop_column_application_services_state_last_txn.sql.sqlite40
-rw-r--r--synapse/storage/schema/main/delta/72/05receipts_event_stream_ordering.sql19
-rw-r--r--synapse/storage/schema/main/delta/72/05remove_unstable_private_read_receipts.sql19
-rw-r--r--synapse/storage/schema/main/delta/72/06add_consent_ts_to_users.sql16
-rw-r--r--synapse/storage/schema/main/delta/72/06thread_notifications.sql30
-rw-r--r--synapse/storage/schema/main/delta/72/07force_update_current_state_events_membership.py52
-rw-r--r--synapse/storage/schema/main/delta/72/07thread_receipts.sql.postgres30
-rw-r--r--synapse/storage/schema/main/delta/72/07thread_receipts.sql.sqlite70
-rw-r--r--synapse/storage/schema/main/delta/72/08begin_cache_invalidation_seq_at_2.sql.postgres23
-rw-r--r--synapse/storage/schema/main/delta/72/08thread_receipts.sql20
-rw-r--r--synapse/storage/schema/main/delta/72/09partial_indices.sql.sqlite56
-rw-r--r--synapse/storage/schema/main/delta/73/01event_failed_pull_attempts.sql29
-rw-r--r--synapse/storage/schema/state/delta/30/state_stream.sql33
-rw-r--r--synapse/storage/state.py9
-rw-r--r--synapse/storage/util/id_generators.py13
-rw-r--r--synapse/storage/util/partial_state_events_tracker.py7
-rw-r--r--synapse/streams/events.py2
-rw-r--r--synapse/types.py5
-rw-r--r--synapse/util/caches/__init__.py59
-rw-r--r--synapse/util/caches/deferred_cache.py346
-rw-r--r--synapse/util/caches/descriptors.py115
-rw-r--r--synapse/util/caches/dictionary_cache.py218
-rw-r--r--synapse/util/caches/lrucache.py145
-rw-r--r--synapse/util/caches/treecache.py41
-rw-r--r--synapse/util/cancellation.py56
-rw-r--r--synapse/util/metrics.py34
-rw-r--r--synapse/util/ratelimitutils.py214
-rw-r--r--synapse/util/rust.py84
-rw-r--r--synapse/visibility.py6
236 files changed, 8334 insertions, 4122 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py
index b1369aca8f..1bed6393bd 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -20,6 +20,8 @@ import json
 import os
 import sys
 
+from synapse.util.rust import check_rust_lib_up_to_date
+
 # Check that we're not running on an unsupported Python version.
 if sys.version_info < (3, 7):
     print("Synapse requires Python 3.7 or above.")
@@ -78,3 +80,6 @@ if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     from synapse.util.patch_inline_callbacks import do_patch
 
     do_patch()
+
+
+check_rust_lib_up_to_date()
diff --git a/synapse/_scripts/generate_config.py b/synapse/_scripts/generate_config.py
index 08eb8ef114..06c11c60da 100755
--- a/synapse/_scripts/generate_config.py
+++ b/synapse/_scripts/generate_config.py
@@ -33,7 +33,7 @@ def main() -> None:
     parser.add_argument(
         "--report-stats",
         action="store",
-        help="Whether the generated config reports anonymized usage statistics",
+        help="Whether the generated config reports homeserver usage statistics",
         choices=["yes", "no"],
     )
 
diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py
index 092601f530..0c4504d5d8 100644
--- a/synapse/_scripts/register_new_matrix_user.py
+++ b/synapse/_scripts/register_new_matrix_user.py
@@ -1,6 +1,6 @@
 # Copyright 2015, 2016 OpenMarket Ltd
 # Copyright 2018 New Vector
-# Copyright 2021 The Matrix.org Foundation C.I.C.
+# Copyright 2021-22 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,11 +20,22 @@ import hashlib
 import hmac
 import logging
 import sys
-from typing import Callable, Optional
+from typing import Any, Callable, Dict, Optional
 
 import requests
 import yaml
 
+_CONFLICTING_SHARED_SECRET_OPTS_ERROR = """\
+Conflicting options 'registration_shared_secret' and 'registration_shared_secret_path'
+are both defined in config file.
+"""
+
+_NO_SHARED_SECRET_OPTS_ERROR = """\
+No 'registration_shared_secret' or 'registration_shared_secret_path' defined in config.
+"""
+
+_DEFAULT_SERVER_URL = "http://localhost:8008"
+
 
 def request_registration(
     user: str,
@@ -203,31 +214,104 @@ def main() -> None:
 
     parser.add_argument(
         "server_url",
-        default="https://localhost:8448",
         nargs="?",
-        help="URL to use to talk to the homeserver. Defaults to "
-        " 'https://localhost:8448'.",
+        help="URL to use to talk to the homeserver. By default, tries to find a "
+        "suitable URL from the configuration file. Otherwise, defaults to "
+        f"'{_DEFAULT_SERVER_URL}'.",
     )
 
     args = parser.parse_args()
 
     if "config" in args and args.config:
         config = yaml.safe_load(args.config)
-        secret = config.get("registration_shared_secret", None)
+
+    if args.shared_secret:
+        secret = args.shared_secret
+    else:
+        # argparse should check that we have either config or shared secret
+        assert config
+
+        secret = config.get("registration_shared_secret")
+        secret_file = config.get("registration_shared_secret_path")
+        if secret_file:
+            if secret:
+                print(_CONFLICTING_SHARED_SECRET_OPTS_ERROR, file=sys.stderr)
+                sys.exit(1)
+            secret = _read_file(secret_file, "registration_shared_secret_path").strip()
         if not secret:
-            print("No 'registration_shared_secret' defined in config.")
+            print(_NO_SHARED_SECRET_OPTS_ERROR, file=sys.stderr)
             sys.exit(1)
+
+    if args.server_url:
+        server_url = args.server_url
+    elif config:
+        server_url = _find_client_listener(config)
+        if not server_url:
+            server_url = _DEFAULT_SERVER_URL
+            print(
+                "Unable to find a suitable HTTP listener in the configuration file. "
+                f"Trying {server_url} as a last resort.",
+                file=sys.stderr,
+            )
     else:
-        secret = args.shared_secret
+        server_url = _DEFAULT_SERVER_URL
+        print(
+            f"No server url or configuration file given. Defaulting to {server_url}.",
+            file=sys.stderr,
+        )
 
     admin = None
     if args.admin or args.no_admin:
         admin = args.admin
 
     register_new_user(
-        args.user, args.password, args.server_url, secret, admin, args.user_type
+        args.user, args.password, server_url, secret, admin, args.user_type
     )
 
 
+def _read_file(file_path: Any, config_path: str) -> str:
+    """Check the given file exists, and read it into a string
+
+    If it does not, exit with an error indicating the problem
+
+    Args:
+        file_path: the file to be read
+        config_path: where in the configuration file_path came from, so that a useful
+           error can be emitted if it does not exist.
+    Returns:
+        content of the file.
+    """
+    if not isinstance(file_path, str):
+        print(f"{config_path} setting is not a string", file=sys.stderr)
+        sys.exit(1)
+
+    try:
+        with open(file_path) as file_stream:
+            return file_stream.read()
+    except OSError as e:
+        print(f"Error accessing file {file_path}: {e}", file=sys.stderr)
+        sys.exit(1)
+
+
+def _find_client_listener(config: Dict[str, Any]) -> Optional[str]:
+    # try to find a listener in the config. Returns a host:port pair
+    for listener in config.get("listeners", []):
+        if listener.get("type") != "http" or listener.get("tls", False):
+            continue
+
+        if not any(
+            name == "client"
+            for resource in listener.get("resources", [])
+            for name in resource.get("names", [])
+        ):
+            continue
+
+        # TODO: consider bind_addresses
+        return f"http://localhost:{listener['port']}"
+
+    # no suitable listeners?
+    return None
+
+
 if __name__ == "__main__":
     main()
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index 26834a437e..30983c47fb 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -67,6 +67,7 @@ from synapse.storage.databases.main.media_repository import (
 )
 from synapse.storage.databases.main.presence import PresenceBackgroundUpdateStore
 from synapse.storage.databases.main.pusher import PusherWorkerStore
+from synapse.storage.databases.main.receipts import ReceiptsBackgroundUpdateStore
 from synapse.storage.databases.main.registration import (
     RegistrationBackgroundUpdateStore,
     find_max_generated_user_id_localpart,
@@ -166,22 +167,6 @@ IGNORED_TABLES = {
     "ui_auth_sessions",
     "ui_auth_sessions_credentials",
     "ui_auth_sessions_ips",
-    # Groups/communities is no longer supported.
-    "group_attestations_remote",
-    "group_attestations_renewals",
-    "group_invites",
-    "group_roles",
-    "group_room_categories",
-    "group_rooms",
-    "group_summary_roles",
-    "group_summary_room_categories",
-    "group_summary_rooms",
-    "group_summary_users",
-    "group_users",
-    "groups",
-    "local_group_membership",
-    "local_group_updates",
-    "remote_profile_cache",
 }
 
 
@@ -219,6 +204,7 @@ class Store(
     PushRuleStore,
     PusherWorkerStore,
     PresenceBackgroundUpdateStore,
+    ReceiptsBackgroundUpdateStore,
 ):
     def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
         return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 6e6eaf3805..3d7f986ac7 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -26,13 +26,20 @@ from synapse.api.errors import (
     Codes,
     InvalidClientTokenError,
     MissingClientTokenError,
+    UnstableSpecAuthError,
 )
 from synapse.appservice import ApplicationService
 from synapse.http import get_request_user_agent
 from synapse.http.site import SynapseRequest
-from synapse.logging.opentracing import active_span, force_tracing, start_active_span
-from synapse.storage.databases.main.registration import TokenLookupResult
-from synapse.types import Requester, UserID, create_requester
+from synapse.logging.opentracing import (
+    SynapseTags,
+    active_span,
+    force_tracing,
+    start_active_span,
+    trace,
+)
+from synapse.types import Requester, create_requester
+from synapse.util.cancellation import cancellable
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -64,14 +71,14 @@ class Auth:
     async def check_user_in_room(
         self,
         room_id: str,
-        user_id: str,
+        requester: Requester,
         allow_departed_users: bool = False,
     ) -> Tuple[str, Optional[str]]:
         """Check if the user is in the room, or was at some point.
         Args:
             room_id: The room to check.
 
-            user_id: The user to check.
+            requester: The user making the request, according to the access token.
 
             current_state: Optional map of the current state of the room.
                 If provided then that map is used to check whether they are a
@@ -88,6 +95,7 @@ class Auth:
             membership event ID of the user.
         """
 
+        user_id = requester.user.to_string()
         (
             membership,
             member_event_id,
@@ -106,9 +114,13 @@ class Auth:
                 forgot = await self.store.did_forget(user_id, room_id)
                 if not forgot:
                     return membership, member_event_id
+        raise UnstableSpecAuthError(
+            403,
+            "User %s not in room %s" % (user_id, room_id),
+            errcode=Codes.NOT_JOINED,
+        )
 
-        raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
-
+    @cancellable
     async def get_user_by_req(
         self,
         request: SynapseRequest,
@@ -150,6 +162,12 @@ class Auth:
                 parent_span.set_tag(
                     "authenticated_entity", requester.authenticated_entity
                 )
+                # We tag the Synapse instance name so that it's an easy jumping
+                # off point into the logs. Can also be used to filter for an
+                # instance that is under load.
+                parent_span.set_tag(
+                    SynapseTags.INSTANCE_NAME, self.hs.get_instance_name()
+                )
                 parent_span.set_tag("user_id", requester.user.to_string())
                 if requester.device_id is not None:
                     parent_span.set_tag("device_id", requester.device_id)
@@ -157,6 +175,7 @@ class Auth:
                     parent_span.set_tag("appservice_id", requester.app_service.id)
             return requester
 
+    @cancellable
     async def _wrapped_get_user_by_req(
         self,
         request: SynapseRequest,
@@ -173,96 +192,69 @@ class Auth:
 
             access_token = self.get_access_token_from_request(request)
 
-            (
-                user_id,
-                device_id,
-                app_service,
-            ) = await self._get_appservice_user_id_and_device_id(request)
-            if user_id and app_service:
-                if ip_addr and self._track_appservice_user_ips:
-                    await self.store.insert_client_ip(
-                        user_id=user_id,
-                        access_token=access_token,
-                        ip=ip_addr,
-                        user_agent=user_agent,
-                        device_id="dummy-device"
-                        if device_id is None
-                        else device_id,  # stubbed
-                    )
-
-                requester = create_requester(
-                    user_id, app_service=app_service, device_id=device_id
+            # First check if it could be a request from an appservice
+            requester = await self._get_appservice_user(request)
+            if not requester:
+                # If not, it should be from a regular user
+                requester = await self.get_user_by_access_token(
+                    access_token, allow_expired=allow_expired
                 )
 
-                request.requester = user_id
-                return requester
-
-            user_info = await self.get_user_by_access_token(
-                access_token, allow_expired=allow_expired
-            )
-            token_id = user_info.token_id
-            is_guest = user_info.is_guest
-            shadow_banned = user_info.shadow_banned
-
-            # Deny the request if the user account has expired.
-            if not allow_expired:
-                if await self._account_validity_handler.is_user_expired(
-                    user_info.user_id
-                ):
-                    # Raise the error if either an account validity module has determined
-                    # the account has expired, or the legacy account validity
-                    # implementation is enabled and determined the account has expired
-                    raise AuthError(
-                        403,
-                        "User account has expired",
-                        errcode=Codes.EXPIRED_ACCOUNT,
-                    )
-
-            device_id = user_info.device_id
-
-            if access_token and ip_addr:
+                # Deny the request if the user account has expired.
+                # This check is only done for regular users, not appservice ones.
+                if not allow_expired:
+                    if await self._account_validity_handler.is_user_expired(
+                        requester.user.to_string()
+                    ):
+                        # Raise the error if either an account validity module has determined
+                        # the account has expired, or the legacy account validity
+                        # implementation is enabled and determined the account has expired
+                        raise AuthError(
+                            403,
+                            "User account has expired",
+                            errcode=Codes.EXPIRED_ACCOUNT,
+                        )
+
+            if ip_addr and (
+                not requester.app_service or self._track_appservice_user_ips
+            ):
+                # XXX(quenting): I'm 95% confident that we could skip setting the
+                # device_id to "dummy-device" for appservices, and that the only impact
+                # would be some rows which whould not deduplicate in the 'user_ips'
+                # table during the transition
+                recorded_device_id = (
+                    "dummy-device"
+                    if requester.device_id is None and requester.app_service is not None
+                    else requester.device_id
+                )
                 await self.store.insert_client_ip(
-                    user_id=user_info.token_owner,
+                    user_id=requester.authenticated_entity,
                     access_token=access_token,
                     ip=ip_addr,
                     user_agent=user_agent,
-                    device_id=device_id,
+                    device_id=recorded_device_id,
                 )
+
                 # Track also the puppeted user client IP if enabled and the user is puppeting
                 if (
-                    user_info.user_id != user_info.token_owner
+                    requester.user.to_string() != requester.authenticated_entity
                     and self._track_puppeted_user_ips
                 ):
                     await self.store.insert_client_ip(
-                        user_id=user_info.user_id,
+                        user_id=requester.user.to_string(),
                         access_token=access_token,
                         ip=ip_addr,
                         user_agent=user_agent,
-                        device_id=device_id,
+                        device_id=requester.device_id,
                     )
 
-            if is_guest and not allow_guest:
+            if requester.is_guest and not allow_guest:
                 raise AuthError(
                     403,
                     "Guest access not allowed",
                     errcode=Codes.GUEST_ACCESS_FORBIDDEN,
                 )
 
-            # Mark the token as used. This is used to invalidate old refresh
-            # tokens after some time.
-            if not user_info.token_used and token_id is not None:
-                await self.store.mark_access_token_as_used(token_id)
-
-            requester = create_requester(
-                user_info.user_id,
-                token_id,
-                is_guest,
-                shadow_banned,
-                device_id,
-                app_service=app_service,
-                authenticated_entity=user_info.token_owner,
-            )
-
             request.requester = requester
             return requester
         except KeyError:
@@ -299,9 +291,8 @@ class Auth:
                 403, "Application service has not registered this user (%s)" % user_id
             )
 
-    async def _get_appservice_user_id_and_device_id(
-        self, request: Request
-    ) -> Tuple[Optional[str], Optional[str], Optional[ApplicationService]]:
+    @cancellable
+    async def _get_appservice_user(self, request: Request) -> Optional[Requester]:
         """
         Given a request, reads the request parameters to determine:
         - whether it's an application service that's making this request
@@ -316,15 +307,13 @@ class Auth:
              Must use `org.matrix.msc3202.device_id` in place of `device_id` for now.
 
         Returns:
-            3-tuple of
-            (user ID?, device ID?, application service?)
+            the application service `Requester` of that request
 
         Postconditions:
-        - If an application service is returned, so is a user ID
-        - A user ID is never returned without an application service
-        - A device ID is never returned without a user ID or an application service
-        - The returned application service, if present, is permitted to control the
-          returned user ID.
+        - The `app_service` field in the returned `Requester` is set
+        - The `user_id` field in the returned `Requester` is either the application
+          service sender or the controlled user set by the `user_id` URI parameter
+        - The returned application service is permitted to control the returned user ID.
         - The returned device ID, if present, has been checked to be a valid device ID
           for the returned user ID.
         """
@@ -334,12 +323,12 @@ class Auth:
             self.get_access_token_from_request(request)
         )
         if app_service is None:
-            return None, None, None
+            return None
 
         if app_service.ip_range_whitelist:
             ip_address = IPAddress(request.getClientAddress().host)
             if ip_address not in app_service.ip_range_whitelist:
-                return None, None, None
+                return None
 
         # This will always be set by the time Twisted calls us.
         assert request.args is not None
@@ -373,13 +362,15 @@ class Auth:
                     Codes.EXCLUSIVE,
                 )
 
-        return effective_user_id, effective_device_id, app_service
+        return create_requester(
+            effective_user_id, app_service=app_service, device_id=effective_device_id
+        )
 
     async def get_user_by_access_token(
         self,
         token: str,
         allow_expired: bool = False,
-    ) -> TokenLookupResult:
+    ) -> Requester:
         """Validate access token and get user_id from it
 
         Args:
@@ -396,9 +387,9 @@ class Auth:
 
         # First look in the database to see if the access token is present
         # as an opaque token.
-        r = await self.store.get_user_by_access_token(token)
-        if r:
-            valid_until_ms = r.valid_until_ms
+        user_info = await self.store.get_user_by_access_token(token)
+        if user_info:
+            valid_until_ms = user_info.valid_until_ms
             if (
                 not allow_expired
                 and valid_until_ms is not None
@@ -410,7 +401,20 @@ class Auth:
                     msg="Access token has expired", soft_logout=True
                 )
 
-            return r
+            # Mark the token as used. This is used to invalidate old refresh
+            # tokens after some time.
+            await self.store.mark_access_token_as_used(user_info.token_id)
+
+            requester = create_requester(
+                user_id=user_info.user_id,
+                access_token_id=user_info.token_id,
+                is_guest=user_info.is_guest,
+                shadow_banned=user_info.shadow_banned,
+                device_id=user_info.device_id,
+                authenticated_entity=user_info.token_owner,
+            )
+
+            return requester
 
         # If the token isn't found in the database, then it could still be a
         # macaroon for a guest, so we check that here.
@@ -436,11 +440,12 @@ class Auth:
                     "Guest access token used for regular user"
                 )
 
-            return TokenLookupResult(
+            return create_requester(
                 user_id=user_id,
                 is_guest=True,
                 # all guests get the same device id
                 device_id=GUEST_DEVICE_ID,
+                authenticated_entity=user_id,
             )
         except (
             pymacaroons.exceptions.MacaroonException,
@@ -454,41 +459,33 @@ class Auth:
             )
             raise InvalidClientTokenError("Invalid access token passed.")
 
-    def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService:
-        token = self.get_access_token_from_request(request)
-        service = self.store.get_app_service_by_token(token)
-        if not service:
-            logger.warning("Unrecognised appservice access token.")
-            raise InvalidClientTokenError()
-        request.requester = create_requester(service.sender, app_service=service)
-        return service
-
-    async def is_server_admin(self, user: UserID) -> bool:
+    async def is_server_admin(self, requester: Requester) -> bool:
         """Check if the given user is a local server admin.
 
         Args:
-            user: user to check
+            requester: The user making the request, according to the access token.
 
         Returns:
             True if the user is an admin
         """
-        return await self.store.is_server_admin(user)
+        return await self.store.is_server_admin(requester.user)
 
-    async def check_can_change_room_list(self, room_id: str, user: UserID) -> bool:
+    async def check_can_change_room_list(
+        self, room_id: str, requester: Requester
+    ) -> bool:
         """Determine whether the user is allowed to edit the room's entry in the
         published room list.
 
         Args:
-            room_id
-            user
+            room_id: The room to check.
+            requester: The user making the request, according to the access token.
         """
 
-        is_admin = await self.is_server_admin(user)
+        is_admin = await self.is_server_admin(requester)
         if is_admin:
             return True
 
-        user_id = user.to_string()
-        await self.check_user_in_room(room_id, user_id)
+        await self.check_user_in_room(room_id, requester)
 
         # We currently require the user is a "moderator" in the room. We do this
         # by checking if they would (theoretically) be able to change the
@@ -507,7 +504,9 @@ class Auth:
         send_level = event_auth.get_send_level(
             EventTypes.CanonicalAlias, "", power_level_event
         )
-        user_level = event_auth.get_user_power_level(user_id, auth_events)
+        user_level = event_auth.get_user_power_level(
+            requester.user.to_string(), auth_events
+        )
 
         return user_level >= send_level
 
@@ -526,6 +525,7 @@ class Auth:
         return bool(query_params) or bool(auth_headers)
 
     @staticmethod
+    @cancellable
     def get_access_token_from_request(request: Request) -> str:
         """Extracts the access_token from the request.
 
@@ -563,17 +563,18 @@ class Auth:
 
             return query_params[0].decode("ascii")
 
+    @trace
     async def check_user_in_room_or_world_readable(
-        self, room_id: str, user_id: str, allow_departed_users: bool = False
+        self, room_id: str, requester: Requester, allow_departed_users: bool = False
     ) -> Tuple[str, Optional[str]]:
         """Checks that the user is or was in the room or the room is world
         readable. If it isn't then an exception is raised.
 
         Args:
-            room_id: room to check
-            user_id: user to check
-            allow_departed_users: if True, accept users that were previously
-                members but have now departed
+            room_id: The room to check.
+            requester: The user making the request, according to the access token.
+            allow_departed_users: If True, accept users that were previously
+                members but have now departed.
 
         Returns:
             Resolves to the current membership of the user in the room and the
@@ -588,7 +589,7 @@ class Auth:
             #  * The user is a guest user, and has joined the room
             # else it will throw.
             return await self.check_user_in_room(
-                room_id, user_id, allow_departed_users=allow_departed_users
+                room_id, requester, allow_departed_users=allow_departed_users
             )
         except AuthError:
             visibility = await self._storage_controllers.state.get_current_state_event(
@@ -600,8 +601,9 @@ class Auth:
                 == HistoryVisibility.WORLD_READABLE
             ):
                 return Membership.JOIN, None
-            raise AuthError(
+            raise UnstableSpecAuthError(
                 403,
                 "User %s not in room %s, and room previews are disabled"
-                % (user_id, room_id),
+                % (requester.user, room_id),
+                errcode=Codes.NOT_JOINED,
             )
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 2653764119..c178ddf070 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -216,11 +216,11 @@ class EventContentFields:
     MSC2716_HISTORICAL: Final = "org.matrix.msc2716.historical"
     # For "insertion" events to indicate what the next batch ID should be in
     # order to connect to it
-    MSC2716_NEXT_BATCH_ID: Final = "org.matrix.msc2716.next_batch_id"
+    MSC2716_NEXT_BATCH_ID: Final = "next_batch_id"
     # Used on "batch" events to indicate which insertion event it connects to
-    MSC2716_BATCH_ID: Final = "org.matrix.msc2716.batch_id"
+    MSC2716_BATCH_ID: Final = "batch_id"
     # For "marker" events
-    MSC2716_MARKER_INSERTION: Final = "org.matrix.msc2716.marker.insertion"
+    MSC2716_INSERTION_EVENT_REFERENCE: Final = "insertion_event_reference"
 
     # The authorising user for joining a restricted room.
     AUTHORISING_USER: Final = "join_authorised_via_users_server"
@@ -257,7 +257,7 @@ class GuestAccess:
 
 class ReceiptTypes:
     READ: Final = "m.read"
-    READ_PRIVATE: Final = "org.matrix.msc2285.read.private"
+    READ_PRIVATE: Final = "m.read.private"
     FULLY_READ: Final = "m.fully_read"
 
 
@@ -268,4 +268,4 @@ class PublicRoomsFilterFields:
     """
 
     GENERIC_SEARCH_TERM: Final = "generic_search_term"
-    ROOM_TYPES: Final = "org.matrix.msc3827.room_types"
+    ROOM_TYPES: Final = "room_types"
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 1c74e131f2..e6dea89c6d 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -26,6 +26,7 @@ from twisted.web import http
 from synapse.util import json_decoder
 
 if typing.TYPE_CHECKING:
+    from synapse.config.homeserver import HomeServerConfig
     from synapse.types import JsonDict
 
 logger = logging.getLogger(__name__)
@@ -80,6 +81,12 @@ class Codes(str, Enum):
     INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
     USER_DEACTIVATED = "M_USER_DEACTIVATED"
 
+    # Part of MSC3848
+    # https://github.com/matrix-org/matrix-spec-proposals/pull/3848
+    ALREADY_JOINED = "ORG.MATRIX.MSC3848.ALREADY_JOINED"
+    NOT_JOINED = "ORG.MATRIX.MSC3848.NOT_JOINED"
+    INSUFFICIENT_POWER = "ORG.MATRIX.MSC3848.INSUFFICIENT_POWER"
+
     # The account has been suspended on the server.
     # By opposition to `USER_DEACTIVATED`, this is a reversible measure
     # that can possibly be appealed and reverted.
@@ -167,7 +174,7 @@ class SynapseError(CodeMessageException):
         else:
             self._additional_fields = dict(additional_fields)
 
-    def error_dict(self) -> "JsonDict":
+    def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
         return cs_error(self.msg, self.errcode, **self._additional_fields)
 
 
@@ -213,7 +220,7 @@ class ConsentNotGivenError(SynapseError):
         )
         self._consent_uri = consent_uri
 
-    def error_dict(self) -> "JsonDict":
+    def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
         return cs_error(self.msg, self.errcode, consent_uri=self._consent_uri)
 
 
@@ -307,6 +314,37 @@ class AuthError(SynapseError):
         super().__init__(code, msg, errcode, additional_fields)
 
 
+class UnstableSpecAuthError(AuthError):
+    """An error raised when a new error code is being proposed to replace a previous one.
+    This error will return a "org.matrix.unstable.errcode" property with the new error code,
+    with the previous error code still being defined in the "errcode" property.
+
+    This error will include `org.matrix.msc3848.unstable.errcode` in the C-S error body.
+    """
+
+    def __init__(
+        self,
+        code: int,
+        msg: str,
+        errcode: str,
+        previous_errcode: str = Codes.FORBIDDEN,
+        additional_fields: Optional[dict] = None,
+    ):
+        self.previous_errcode = previous_errcode
+        super().__init__(code, msg, errcode, additional_fields)
+
+    def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
+        fields = {}
+        if config is not None and config.experimental.msc3848_enabled:
+            fields["org.matrix.msc3848.unstable.errcode"] = self.errcode
+        return cs_error(
+            self.msg,
+            self.previous_errcode,
+            **fields,
+            **self._additional_fields,
+        )
+
+
 class InvalidClientCredentialsError(SynapseError):
     """An error raised when there was a problem with the authorisation credentials
     in a client request.
@@ -338,8 +376,8 @@ class InvalidClientTokenError(InvalidClientCredentialsError):
         super().__init__(msg=msg, errcode="M_UNKNOWN_TOKEN")
         self._soft_logout = soft_logout
 
-    def error_dict(self) -> "JsonDict":
-        d = super().error_dict()
+    def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
+        d = super().error_dict(config)
         d["soft_logout"] = self._soft_logout
         return d
 
@@ -362,7 +400,7 @@ class ResourceLimitError(SynapseError):
         self.limit_type = limit_type
         super().__init__(code, msg, errcode=errcode)
 
-    def error_dict(self) -> "JsonDict":
+    def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
         return cs_error(
             self.msg,
             self.errcode,
@@ -397,7 +435,7 @@ class InvalidCaptchaError(SynapseError):
         super().__init__(code, msg, errcode)
         self.error_url = error_url
 
-    def error_dict(self) -> "JsonDict":
+    def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
         return cs_error(self.msg, self.errcode, error_url=self.error_url)
 
 
@@ -414,7 +452,7 @@ class LimitExceededError(SynapseError):
         super().__init__(code, msg, errcode)
         self.retry_after_ms = retry_after_ms
 
-    def error_dict(self) -> "JsonDict":
+    def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
         return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms)
 
 
@@ -429,7 +467,7 @@ class RoomKeysVersionError(SynapseError):
         super().__init__(403, "Wrong room_keys version", Codes.WRONG_ROOM_KEYS_VERSION)
         self.current_version = current_version
 
-    def error_dict(self) -> "JsonDict":
+    def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
         return cs_error(self.msg, self.errcode, current_version=self.current_version)
 
 
@@ -469,7 +507,7 @@ class IncompatibleRoomVersionError(SynapseError):
 
         self._room_version = room_version
 
-    def error_dict(self) -> "JsonDict":
+    def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
         return cs_error(self.msg, self.errcode, room_version=self._room_version)
 
 
@@ -515,7 +553,7 @@ class UnredactedContentDeletedError(SynapseError):
         )
         self.content_keep_ms = content_keep_ms
 
-    def error_dict(self) -> "JsonDict":
+    def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
         extra = {}
         if self.content_keep_ms is not None:
             extra = {"fi.mau.msc2815.content_keep_ms": self.content_keep_ms}
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index b007147519..f7f46f8d80 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -140,13 +140,13 @@ USER_FILTER_SCHEMA = {
 
 
 @FormatChecker.cls_checks("matrix_room_id")
-def matrix_room_id_validator(room_id_str: str) -> RoomID:
-    return RoomID.from_string(room_id_str)
+def matrix_room_id_validator(room_id: object) -> bool:
+    return isinstance(room_id, str) and RoomID.is_valid(room_id)
 
 
 @FormatChecker.cls_checks("matrix_user_id")
-def matrix_user_id_validator(user_id_str: str) -> UserID:
-    return UserID.from_string(user_id_str)
+def matrix_user_id_validator(user_id: object) -> bool:
+    return isinstance(user_id, str) and UserID.is_valid(user_id)
 
 
 class Filtering:
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index 54d13026c9..044c7d4926 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -17,7 +17,7 @@ from collections import OrderedDict
 from typing import Hashable, Optional, Tuple
 
 from synapse.api.errors import LimitExceededError
-from synapse.config.ratelimiting import RateLimitConfig
+from synapse.config.ratelimiting import RatelimitSettings
 from synapse.storage.databases.main import DataStore
 from synapse.types import Requester
 from synapse.util import Clock
@@ -27,6 +27,33 @@ class Ratelimiter:
     """
     Ratelimit actions marked by arbitrary keys.
 
+    (Note that the source code speaks of "actions" and "burst_count" rather than
+    "tokens" and a "bucket_size".)
+
+    This is a "leaky bucket as a meter". For each key to be tracked there is a bucket
+    containing some number 0 <= T <= `burst_count` of tokens corresponding to previously
+    permitted requests for that key. Each bucket starts empty, and gradually leaks
+    tokens at a rate of `rate_hz`.
+
+    Upon an incoming request, we must determine:
+    - the key that this request falls under (which bucket to inspect), and
+    - the cost C of this request in tokens.
+    Then, if there is room in the bucket for C tokens (T + C <= `burst_count`),
+    the request is permitted and `cost` tokens are added to the bucket.
+    Otherwise the request is denied, and the bucket continues to hold T tokens.
+
+    This means that the limiter enforces an average request frequency of `rate_hz`,
+    while accumulating a buffer of up to `burst_count` requests which can be consumed
+    instantaneously.
+
+    The tricky bit is the leaking. We do not want to have a periodic process which
+    leaks every bucket! Instead, we track
+    - the time point when the bucket was last completely empty, and
+    - how many tokens have added to the bucket permitted since then.
+    Then for each incoming request, we can calculate how many tokens have leaked
+    since this time point, and use that to decide if we should accept or reject the
+    request.
+
     Args:
         clock: A homeserver clock, for retrieving the current time
         rate_hz: The long term number of actions that can be performed in a second.
@@ -41,14 +68,30 @@ class Ratelimiter:
         self.burst_count = burst_count
         self.store = store
 
-        # A ordered dictionary keeping track of actions, when they were last
-        # performed and how often. Each entry is a mapping from a key of arbitrary type
-        # to a tuple representing:
-        #   * How many times an action has occurred since a point in time
-        #   * The point in time
-        #   * The rate_hz of this particular entry. This can vary per request
+        # An ordered dictionary representing the token buckets tracked by this rate
+        # limiter. Each entry maps a key of arbitrary type to a tuple representing:
+        #   * The number of tokens currently in the bucket,
+        #   * The time point when the bucket was last completely empty, and
+        #   * The rate_hz (leak rate) of this particular bucket.
         self.actions: OrderedDict[Hashable, Tuple[float, float, float]] = OrderedDict()
 
+    def _get_key(
+        self, requester: Optional[Requester], key: Optional[Hashable]
+    ) -> Hashable:
+        """Use the requester's MXID as a fallback key if no key is provided."""
+        if key is None:
+            if not requester:
+                raise ValueError("Must supply at least one of `requester` or `key`")
+
+            key = requester.user.to_string()
+        return key
+
+    def _get_action_counts(
+        self, key: Hashable, time_now_s: float
+    ) -> Tuple[float, float, float]:
+        """Retrieve the action counts, with a fallback representing an empty bucket."""
+        return self.actions.get(key, (0.0, time_now_s, 0.0))
+
     async def can_do_action(
         self,
         requester: Optional[Requester],
@@ -88,11 +131,7 @@ class Ratelimiter:
                 * The reactor timestamp for when the action can be performed next.
                   -1 if rate_hz is less than or equal to zero
         """
-        if key is None:
-            if not requester:
-                raise ValueError("Must supply at least one of `requester` or `key`")
-
-            key = requester.user.to_string()
+        key = self._get_key(requester, key)
 
         if requester:
             # Disable rate limiting of users belonging to any AS that is configured
@@ -121,7 +160,7 @@ class Ratelimiter:
         self._prune_message_counts(time_now_s)
 
         # Check if there is an existing count entry for this key
-        action_count, time_start, _ = self.actions.get(key, (0.0, time_now_s, 0.0))
+        action_count, time_start, _ = self._get_action_counts(key, time_now_s)
 
         # Check whether performing another action is allowed
         time_delta = time_now_s - time_start
@@ -164,6 +203,37 @@ class Ratelimiter:
 
         return allowed, time_allowed
 
+    def record_action(
+        self,
+        requester: Optional[Requester],
+        key: Optional[Hashable] = None,
+        n_actions: int = 1,
+        _time_now_s: Optional[float] = None,
+    ) -> None:
+        """Record that an action(s) took place, even if they violate the rate limit.
+
+        This is useful for tracking the frequency of events that happen across
+        federation which we still want to impose local rate limits on. For instance, if
+        we are alice.com monitoring a particular room, we cannot prevent bob.com
+        from joining users to that room. However, we can track the number of recent
+        joins in the room and refuse to serve new joins ourselves if there have been too
+        many in the room across both homeservers.
+
+        Args:
+            requester: The requester that is doing the action, if any.
+            key: An arbitrary key used to classify an action. Defaults to the
+                requester's user ID.
+            n_actions: The number of times the user wants to do this action. If the user
+                cannot do all of the actions, the user's action count is not incremented
+                at all.
+            _time_now_s: The current time. Optional, defaults to the current time according
+                to self.clock. Only used by tests.
+        """
+        key = self._get_key(requester, key)
+        time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
+        action_count, time_start, rate_hz = self._get_action_counts(key, time_now_s)
+        self.actions[key] = (action_count + n_actions, time_start, rate_hz)
+
     def _prune_message_counts(self, time_now_s: float) -> None:
         """Remove message count entries that have not exceeded their defined
         rate_hz limit
@@ -244,8 +314,8 @@ class RequestRatelimiter:
         self,
         store: DataStore,
         clock: Clock,
-        rc_message: RateLimitConfig,
-        rc_admin_redaction: Optional[RateLimitConfig],
+        rc_message: RatelimitSettings,
+        rc_admin_redaction: Optional[RatelimitSettings],
     ):
         self.store = store
         self.clock = clock
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index 3f85d61b46..e37acb0f1e 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -19,18 +19,23 @@ import attr
 
 class EventFormatVersions:
     """This is an internal enum for tracking the version of the event format,
-    independently from the room version.
+    independently of the room version.
+
+    To reduce confusion, the event format versions are named after the room
+    versions that they were used or introduced in.
+    The concept of an 'event format version' is specific to Synapse (the
+    specification does not mention this term.)
     """
 
-    V1 = 1  # $id:server event id format
-    V2 = 2  # MSC1659-style $hash event id format: introduced for room v3
-    V3 = 3  # MSC1884-style $hash format: introduced for room v4
+    ROOM_V1_V2 = 1  # $id:server event id format: used for room v1 and v2
+    ROOM_V3 = 2  # MSC1659-style $hash event id format: used for room v3
+    ROOM_V4_PLUS = 3  # MSC1884-style $hash format: introduced for room v4
 
 
 KNOWN_EVENT_FORMAT_VERSIONS = {
-    EventFormatVersions.V1,
-    EventFormatVersions.V2,
-    EventFormatVersions.V3,
+    EventFormatVersions.ROOM_V1_V2,
+    EventFormatVersions.ROOM_V3,
+    EventFormatVersions.ROOM_V4_PLUS,
 }
 
 
@@ -84,13 +89,15 @@ class RoomVersion:
     # MSC3787: Adds support for a `knock_restricted` join rule, mixing concepts of
     # knocks and restricted join rules into the same join condition.
     msc3787_knock_restricted_join_rule: bool
+    # MSC3667: Enforce integer power levels
+    msc3667_int_only_power_levels: bool
 
 
 class RoomVersions:
     V1 = RoomVersion(
         "1",
         RoomDisposition.STABLE,
-        EventFormatVersions.V1,
+        EventFormatVersions.ROOM_V1_V2,
         StateResolutionVersions.V1,
         enforce_key_validity=False,
         special_case_aliases_auth=True,
@@ -103,11 +110,12 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
         msc3787_knock_restricted_join_rule=False,
+        msc3667_int_only_power_levels=False,
     )
     V2 = RoomVersion(
         "2",
         RoomDisposition.STABLE,
-        EventFormatVersions.V1,
+        EventFormatVersions.ROOM_V1_V2,
         StateResolutionVersions.V2,
         enforce_key_validity=False,
         special_case_aliases_auth=True,
@@ -120,11 +128,12 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
         msc3787_knock_restricted_join_rule=False,
+        msc3667_int_only_power_levels=False,
     )
     V3 = RoomVersion(
         "3",
         RoomDisposition.STABLE,
-        EventFormatVersions.V2,
+        EventFormatVersions.ROOM_V3,
         StateResolutionVersions.V2,
         enforce_key_validity=False,
         special_case_aliases_auth=True,
@@ -137,11 +146,12 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
         msc3787_knock_restricted_join_rule=False,
+        msc3667_int_only_power_levels=False,
     )
     V4 = RoomVersion(
         "4",
         RoomDisposition.STABLE,
-        EventFormatVersions.V3,
+        EventFormatVersions.ROOM_V4_PLUS,
         StateResolutionVersions.V2,
         enforce_key_validity=False,
         special_case_aliases_auth=True,
@@ -154,11 +164,12 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
         msc3787_knock_restricted_join_rule=False,
+        msc3667_int_only_power_levels=False,
     )
     V5 = RoomVersion(
         "5",
         RoomDisposition.STABLE,
-        EventFormatVersions.V3,
+        EventFormatVersions.ROOM_V4_PLUS,
         StateResolutionVersions.V2,
         enforce_key_validity=True,
         special_case_aliases_auth=True,
@@ -171,11 +182,12 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
         msc3787_knock_restricted_join_rule=False,
+        msc3667_int_only_power_levels=False,
     )
     V6 = RoomVersion(
         "6",
         RoomDisposition.STABLE,
-        EventFormatVersions.V3,
+        EventFormatVersions.ROOM_V4_PLUS,
         StateResolutionVersions.V2,
         enforce_key_validity=True,
         special_case_aliases_auth=False,
@@ -188,11 +200,12 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
         msc3787_knock_restricted_join_rule=False,
+        msc3667_int_only_power_levels=False,
     )
     MSC2176 = RoomVersion(
         "org.matrix.msc2176",
         RoomDisposition.UNSTABLE,
-        EventFormatVersions.V3,
+        EventFormatVersions.ROOM_V4_PLUS,
         StateResolutionVersions.V2,
         enforce_key_validity=True,
         special_case_aliases_auth=False,
@@ -205,11 +218,12 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
         msc3787_knock_restricted_join_rule=False,
+        msc3667_int_only_power_levels=False,
     )
     V7 = RoomVersion(
         "7",
         RoomDisposition.STABLE,
-        EventFormatVersions.V3,
+        EventFormatVersions.ROOM_V4_PLUS,
         StateResolutionVersions.V2,
         enforce_key_validity=True,
         special_case_aliases_auth=False,
@@ -222,11 +236,12 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
         msc3787_knock_restricted_join_rule=False,
+        msc3667_int_only_power_levels=False,
     )
     V8 = RoomVersion(
         "8",
         RoomDisposition.STABLE,
-        EventFormatVersions.V3,
+        EventFormatVersions.ROOM_V4_PLUS,
         StateResolutionVersions.V2,
         enforce_key_validity=True,
         special_case_aliases_auth=False,
@@ -239,11 +254,12 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
         msc3787_knock_restricted_join_rule=False,
+        msc3667_int_only_power_levels=False,
     )
     V9 = RoomVersion(
         "9",
         RoomDisposition.STABLE,
-        EventFormatVersions.V3,
+        EventFormatVersions.ROOM_V4_PLUS,
         StateResolutionVersions.V2,
         enforce_key_validity=True,
         special_case_aliases_auth=False,
@@ -256,28 +272,30 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
         msc3787_knock_restricted_join_rule=False,
+        msc3667_int_only_power_levels=False,
     )
-    MSC2716v3 = RoomVersion(
-        "org.matrix.msc2716v3",
+    MSC3787 = RoomVersion(
+        "org.matrix.msc3787",
         RoomDisposition.UNSTABLE,
-        EventFormatVersions.V3,
+        EventFormatVersions.ROOM_V4_PLUS,
         StateResolutionVersions.V2,
         enforce_key_validity=True,
         special_case_aliases_auth=False,
         strict_canonicaljson=True,
         limit_notifications_power_levels=True,
         msc2176_redaction_rules=False,
-        msc3083_join_rules=False,
-        msc3375_redaction_rules=False,
+        msc3083_join_rules=True,
+        msc3375_redaction_rules=True,
         msc2403_knocking=True,
-        msc2716_historical=True,
-        msc2716_redactions=True,
-        msc3787_knock_restricted_join_rule=False,
+        msc2716_historical=False,
+        msc2716_redactions=False,
+        msc3787_knock_restricted_join_rule=True,
+        msc3667_int_only_power_levels=False,
     )
-    MSC3787 = RoomVersion(
-        "org.matrix.msc3787",
-        RoomDisposition.UNSTABLE,
-        EventFormatVersions.V3,
+    V10 = RoomVersion(
+        "10",
+        RoomDisposition.STABLE,
+        EventFormatVersions.ROOM_V4_PLUS,
         StateResolutionVersions.V2,
         enforce_key_validity=True,
         special_case_aliases_auth=False,
@@ -290,6 +308,25 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
         msc3787_knock_restricted_join_rule=True,
+        msc3667_int_only_power_levels=True,
+    )
+    MSC2716v4 = RoomVersion(
+        "org.matrix.msc2716v4",
+        RoomDisposition.UNSTABLE,
+        EventFormatVersions.ROOM_V4_PLUS,
+        StateResolutionVersions.V2,
+        enforce_key_validity=True,
+        special_case_aliases_auth=False,
+        strict_canonicaljson=True,
+        limit_notifications_power_levels=True,
+        msc2176_redaction_rules=False,
+        msc3083_join_rules=False,
+        msc3375_redaction_rules=False,
+        msc2403_knocking=True,
+        msc2716_historical=True,
+        msc2716_redactions=True,
+        msc3787_knock_restricted_join_rule=False,
+        msc3667_int_only_power_levels=False,
     )
 
 
@@ -306,8 +343,9 @@ KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = {
         RoomVersions.V7,
         RoomVersions.V8,
         RoomVersions.V9,
-        RoomVersions.MSC2716v3,
         RoomVersions.MSC3787,
+        RoomVersions.V10,
+        RoomVersions.MSC2716v4,
     )
 }
 
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 923891ae0d..9a24bed0a0 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -266,15 +266,48 @@ def register_start(
     reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper()))
 
 
-def listen_metrics(bind_addresses: Iterable[str], port: int) -> None:
+def listen_metrics(
+    bind_addresses: Iterable[str], port: int, enable_legacy_metric_names: bool
+) -> None:
     """
     Start Prometheus metrics server.
     """
-    from synapse.metrics import RegistryProxy, start_http_server
+    from prometheus_client import start_http_server as start_http_server_prometheus
+
+    from synapse.metrics import (
+        RegistryProxy,
+        start_http_server as start_http_server_legacy,
+    )
 
     for host in bind_addresses:
         logger.info("Starting metrics listener on %s:%d", host, port)
-        start_http_server(port, addr=host, registry=RegistryProxy)
+        if enable_legacy_metric_names:
+            start_http_server_legacy(port, addr=host, registry=RegistryProxy)
+        else:
+            _set_prometheus_client_use_created_metrics(False)
+            start_http_server_prometheus(port, addr=host, registry=RegistryProxy)
+
+
+def _set_prometheus_client_use_created_metrics(new_value: bool) -> None:
+    """
+    Sets whether prometheus_client should expose `_created`-suffixed metrics for
+    all gauges, histograms and summaries.
+    There is no programmatic way to disable this without poking at internals;
+    the proper way is to use an environment variable which prometheus_client
+    loads at import time.
+
+    The motivation for disabling these `_created` metrics is that they're
+    a waste of space as they're not useful but they take up space in Prometheus.
+    """
+
+    import prometheus_client.metrics
+
+    if hasattr(prometheus_client.metrics, "_use_created"):
+        prometheus_client.metrics._use_created = new_value
+    else:
+        logger.error(
+            "Can't disable `_created` metrics in prometheus_client (brittle hack broken?)"
+        )
 
 
 def listen_manhole(
@@ -478,9 +511,10 @@ async def start(hs: "HomeServer") -> None:
     setup_sentry(hs)
     setup_sdnotify(hs)
 
-    # If background tasks are running on the main process, start collecting the
-    # phone home stats.
+    # If background tasks are running on the main process or this is the worker in
+    # charge of them, start collecting the phone home stats and shared usage metrics.
     if hs.config.worker.run_background_tasks:
+        await hs.get_common_usage_metrics_manager().setup()
         start_phone_stats_home(hs)
 
     # We now freeze all allocated objects in the hopes that (almost)
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index 87f82bd9a5..8a583d3ec6 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -28,19 +28,22 @@ from synapse.config.homeserver import HomeServerConfig
 from synapse.config.logger import setup_logging
 from synapse.events import EventBase
 from synapse.handlers.admin import ExfiltrationWriter
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
-from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
-from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
 from synapse.replication.slave.storage.devices import SlavedDeviceStore
 from synapse.replication.slave.storage.events import SlavedEventStore
 from synapse.replication.slave.storage.filtering import SlavedFilteringStore
 from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
-from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
-from synapse.replication.slave.storage.registration import SlavedRegistrationStore
 from synapse.server import HomeServer
 from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.databases.main.account_data import AccountDataWorkerStore
+from synapse.storage.databases.main.appservice import (
+    ApplicationServiceTransactionWorkerStore,
+    ApplicationServiceWorkerStore,
+)
+from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
+from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
+from synapse.storage.databases.main.registration import RegistrationWorkerStore
 from synapse.storage.databases.main.room import RoomWorkerStore
+from synapse.storage.databases.main.tags import TagsWorkerStore
 from synapse.types import StateMap
 from synapse.util import SYNAPSE_VERSION
 from synapse.util.logcontext import LoggingContext
@@ -49,16 +52,17 @@ logger = logging.getLogger("synapse.app.admin_cmd")
 
 
 class AdminCmdSlavedStore(
-    SlavedReceiptsStore,
-    SlavedAccountDataStore,
-    SlavedApplicationServiceStore,
-    SlavedRegistrationStore,
     SlavedFilteringStore,
-    SlavedDeviceInboxStore,
     SlavedDeviceStore,
     SlavedPushRuleStore,
     SlavedEventStore,
-    BaseSlavedStore,
+    TagsWorkerStore,
+    DeviceInboxWorkerStore,
+    AccountDataWorkerStore,
+    ApplicationServiceTransactionWorkerStore,
+    ApplicationServiceWorkerStore,
+    RegistrationWorkerStore,
+    ReceiptsWorkerStore,
     RoomWorkerStore,
 ):
     def __init__(
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 4a987fb759..5e3825fca6 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -48,20 +48,12 @@ from synapse.http.site import SynapseRequest, SynapseSite
 from synapse.logging.context import LoggingContext
 from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
 from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
-from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
-from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
 from synapse.replication.slave.storage.devices import SlavedDeviceStore
-from synapse.replication.slave.storage.directory import DirectoryStore
 from synapse.replication.slave.storage.events import SlavedEventStore
 from synapse.replication.slave.storage.filtering import SlavedFilteringStore
 from synapse.replication.slave.storage.keys import SlavedKeyStore
-from synapse.replication.slave.storage.profile import SlavedProfileStore
 from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
 from synapse.replication.slave.storage.pushers import SlavedPusherStore
-from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
-from synapse.replication.slave.storage.registration import SlavedRegistrationStore
 from synapse.rest.admin import register_servlets_for_media_repo
 from synapse.rest.client import (
     account_data,
@@ -100,8 +92,15 @@ from synapse.rest.key.v2 import KeyApiV2Resource
 from synapse.rest.synapse.client import build_synapse_client_resource_tree
 from synapse.rest.well_known import well_known_resource
 from synapse.server import HomeServer
+from synapse.storage.databases.main.account_data import AccountDataWorkerStore
+from synapse.storage.databases.main.appservice import (
+    ApplicationServiceTransactionWorkerStore,
+    ApplicationServiceWorkerStore,
+)
 from synapse.storage.databases.main.censor_events import CensorEventsStore
 from synapse.storage.databases.main.client_ips import ClientIpWorkerStore
+from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
+from synapse.storage.databases.main.directory import DirectoryWorkerStore
 from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyStore
 from synapse.storage.databases.main.lock import LockStore
 from synapse.storage.databases.main.media_repository import MediaRepositoryStore
@@ -110,11 +109,15 @@ from synapse.storage.databases.main.monthly_active_users import (
     MonthlyActiveUsersWorkerStore,
 )
 from synapse.storage.databases.main.presence import PresenceStore
+from synapse.storage.databases.main.profile import ProfileWorkerStore
+from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
+from synapse.storage.databases.main.registration import RegistrationWorkerStore
 from synapse.storage.databases.main.room import RoomWorkerStore
 from synapse.storage.databases.main.room_batch import RoomBatchStore
 from synapse.storage.databases.main.search import SearchStore
 from synapse.storage.databases.main.session import SessionStore
 from synapse.storage.databases.main.stats import StatsStore
+from synapse.storage.databases.main.tags import TagsWorkerStore
 from synapse.storage.databases.main.transactions import TransactionWorkerStore
 from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
 from synapse.storage.databases.main.user_directory import UserDirectoryStore
@@ -227,11 +230,11 @@ class GenericWorkerSlavedStore(
     UIAuthWorkerStore,
     EndToEndRoomKeyStore,
     PresenceStore,
-    SlavedDeviceInboxStore,
+    DeviceInboxWorkerStore,
     SlavedDeviceStore,
-    SlavedReceiptsStore,
     SlavedPushRuleStore,
-    SlavedAccountDataStore,
+    TagsWorkerStore,
+    AccountDataWorkerStore,
     SlavedPusherStore,
     CensorEventsStore,
     ClientIpWorkerStore,
@@ -239,19 +242,20 @@ class GenericWorkerSlavedStore(
     SlavedKeyStore,
     RoomWorkerStore,
     RoomBatchStore,
-    DirectoryStore,
-    SlavedApplicationServiceStore,
-    SlavedRegistrationStore,
-    SlavedProfileStore,
+    DirectoryWorkerStore,
+    ApplicationServiceTransactionWorkerStore,
+    ApplicationServiceWorkerStore,
+    ProfileWorkerStore,
     SlavedFilteringStore,
     MonthlyActiveUsersWorkerStore,
     MediaRepositoryStore,
     ServerMetricsStore,
+    ReceiptsWorkerStore,
+    RegistrationWorkerStore,
     SearchStore,
     TransactionWorkerStore,
     LockStore,
     SessionStore,
-    BaseSlavedStore,
 ):
     # Properties that multiple storage classes define. Tell mypy what the
     # expected type is.
@@ -408,7 +412,11 @@ class GenericWorkerServer(HomeServer):
                         "enable_metrics is not True!"
                     )
                 else:
-                    _base.listen_metrics(listener.bind_addresses, listener.port)
+                    _base.listen_metrics(
+                        listener.bind_addresses,
+                        listener.port,
+                        enable_legacy_metric_names=self.config.metrics.enable_legacy_metrics,
+                    )
             else:
                 logger.warning("Unsupported listener type: %s", listener.type)
 
@@ -437,6 +445,13 @@ def start(config_options: List[str]) -> None:
         "synapse.app.user_dir",
     )
 
+    if config.experimental.faster_joins_enabled:
+        raise ConfigError(
+            "You have enabled the experimental `faster_joins` config option, but it is "
+            "not compatible with worker deployments yet. Please disable `faster_joins` "
+            "or run Synapse as a single process deployment instead."
+        )
+
     synapse.events.USE_FROZEN_DICTS = config.server.use_frozen_dicts
     synapse.util.caches.TRACK_MEMORY_USAGE = config.caches.track_memory_usage
 
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 745e704141..883f2fd2ec 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -44,7 +44,6 @@ from synapse.app._base import (
     register_start,
 )
 from synapse.config._base import ConfigError, format_config_error
-from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.server import ListenerConfig
 from synapse.federation.transport.server import TransportLayerServer
@@ -58,7 +57,6 @@ from synapse.http.site import SynapseSite
 from synapse.logging.context import LoggingContext
 from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
 from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
-from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
 from synapse.rest import ClientRestResource
 from synapse.rest.admin import AdminRestResource
 from synapse.rest.health import HealthResource
@@ -202,7 +200,7 @@ class SynapseHomeServer(HomeServer):
                 }
             )
 
-            if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+            if self.config.email.can_verify_email:
                 from synapse.rest.synapse.client.password_reset import (
                     PasswordResetSubmitTokenResource,
                 )
@@ -220,7 +218,10 @@ class SynapseHomeServer(HomeServer):
             resources.update({"/_matrix/consent": consent_resource})
 
         if name == "federation":
-            resources.update({FEDERATION_PREFIX: TransportLayerServer(self)})
+            federation_resource: Resource = TransportLayerServer(self)
+            if compress:
+                federation_resource = gz_wrap(federation_resource)
+            resources.update({FEDERATION_PREFIX: federation_resource})
 
         if name == "openid":
             resources.update(
@@ -288,16 +289,6 @@ class SynapseHomeServer(HomeServer):
                     manhole_settings=self.config.server.manhole_settings,
                     manhole_globals={"hs": self},
                 )
-            elif listener.type == "replication":
-                services = listen_tcp(
-                    listener.bind_addresses,
-                    listener.port,
-                    ReplicationStreamProtocolFactory(self),
-                )
-                for s in services:
-                    self.get_reactor().addSystemEventTrigger(
-                        "before", "shutdown", s.stopListening
-                    )
             elif listener.type == "metrics":
                 if not self.config.metrics.enable_metrics:
                     logger.warning(
@@ -305,7 +296,11 @@ class SynapseHomeServer(HomeServer):
                         "enable_metrics is not True!"
                     )
                 else:
-                    _base.listen_metrics(listener.bind_addresses, listener.port)
+                    _base.listen_metrics(
+                        listener.bind_addresses,
+                        listener.port,
+                        enable_legacy_metric_names=self.config.metrics.enable_legacy_metrics,
+                    )
             else:
                 # this shouldn't happen, as the listener type should have been checked
                 # during parsing
diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py
index 40dbdace8e..53db1e85b3 100644
--- a/synapse/app/phone_stats_home.py
+++ b/synapse/app/phone_stats_home.py
@@ -32,15 +32,15 @@ logger = logging.getLogger("synapse.app.homeserver")
 _stats_process: List[Tuple[int, "resource.struct_rusage"]] = []
 
 # Gauges to expose monthly active user control metrics
-current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU")
+current_mau_gauge = Gauge("synapse_admin_mau_current", "Current MAU")
 current_mau_by_service_gauge = Gauge(
     "synapse_admin_mau_current_mau_by_service",
     "Current MAU by service",
     ["app_service"],
 )
-max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit")
+max_mau_gauge = Gauge("synapse_admin_mau_max", "MAU Limit")
 registered_reserved_users_mau_gauge = Gauge(
-    "synapse_admin_mau:registered_reserved_users",
+    "synapse_admin_mau_registered_reserved_users",
     "Registered users with reserved threepids",
 )
 
@@ -51,6 +51,16 @@ async def phone_stats_home(
     stats: JsonDict,
     stats_process: List[Tuple[int, "resource.struct_rusage"]] = _stats_process,
 ) -> None:
+    """Collect usage statistics and send them to the configured endpoint.
+
+    Args:
+        hs: the HomeServer object to use for gathering usage data.
+        stats: the dict in which to store the statistics sent to the configured
+            endpoint. Mostly used in tests to figure out the data that is supposed to
+            be sent.
+        stats_process: statistics about resource usage of the process.
+    """
+
     logger.info("Gathering stats for reporting")
     now = int(hs.get_clock().time())
     # Ensure the homeserver has started.
@@ -83,6 +93,7 @@ async def phone_stats_home(
     #
 
     store = hs.get_datastores().main
+    common_metrics = await hs.get_common_usage_metrics_manager().get_metrics()
 
     stats["homeserver"] = hs.config.server.server_name
     stats["server_context"] = hs.config.server.server_context
@@ -104,7 +115,7 @@ async def phone_stats_home(
     room_count = await store.get_room_count()
     stats["total_room_count"] = room_count
 
-    stats["daily_active_users"] = await store.count_daily_users()
+    stats["daily_active_users"] = common_metrics.daily_active_users
     stats["monthly_active_users"] = await store.count_monthly_users()
     daily_active_e2ee_rooms = await store.count_daily_active_e2ee_rooms()
     stats["daily_active_e2ee_rooms"] = daily_active_e2ee_rooms
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index df1c214462..0963fb3bb4 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -53,6 +53,18 @@ sent_events_counter = Counter(
     "synapse_appservice_api_sent_events", "Number of events sent to the AS", ["service"]
 )
 
+sent_ephemeral_counter = Counter(
+    "synapse_appservice_api_sent_ephemeral",
+    "Number of ephemeral events sent to the AS",
+    ["service"],
+)
+
+sent_todevice_counter = Counter(
+    "synapse_appservice_api_sent_todevice",
+    "Number of todevice messages sent to the AS",
+    ["service"],
+)
+
 HOUR_IN_MS = 60 * 60 * 1000
 
 
@@ -310,6 +322,8 @@ class ApplicationServiceApi(SimpleHttpClient):
                 )
             sent_transactions_counter.labels(service.id).inc()
             sent_events_counter.labels(service.id).inc(len(serialized_events))
+            sent_ephemeral_counter.labels(service.id).inc(len(ephemeral))
+            sent_todevice_counter.labels(service.id).inc(len(to_device_messages))
             return True
         except CodeMessageException as e:
             logger.warning(
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 095eca16c5..1f6362aedd 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -20,6 +20,7 @@ import logging
 import os
 import re
 from collections import OrderedDict
+from enum import Enum, auto
 from hashlib import sha256
 from textwrap import dedent
 from typing import (
@@ -97,16 +98,16 @@ def format_config_error(e: ConfigError) -> Iterator[str]:
 # We split these messages out to allow packages to override with package
 # specific instructions.
 MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS = """\
-Please opt in or out of reporting anonymized homeserver usage statistics, by
-setting the `report_stats` key in your config file to either True or False.
+Please opt in or out of reporting homeserver usage statistics, by setting
+the `report_stats` key in your config file to either True or False.
 """
 
 MISSING_REPORT_STATS_SPIEL = """\
 We would really appreciate it if you could help our project out by reporting
-anonymized usage statistics from your homeserver. Only very basic aggregate
-data (e.g. number of users) will be reported, but it helps us to track the
-growth of the Matrix community, and helps us to make Matrix a success, as well
-as to convince other networks that they should peer with us.
+homeserver usage statistics from your homeserver. Your homeserver's server name,
+along with very basic aggregate data (e.g. number of users) will be reported. But
+it helps us to track the growth of the Matrix community, and helps us to make Matrix
+a success, as well as to convince other networks that they should peer with us.
 
 Thank you.
 """
@@ -603,25 +604,51 @@ class RootConfig:
             " may specify directories containing *.yaml files.",
         )
 
-        generate_group = parser.add_argument_group("Config generation")
-        generate_group.add_argument(
+        # we nest the mutually-exclusive group inside another group so that the help
+        # text shows them in their own group.
+        generate_mode_group = parser.add_argument_group(
+            "Config generation mode",
+        )
+        generate_mode_exclusive = generate_mode_group.add_mutually_exclusive_group()
+        generate_mode_exclusive.add_argument(
+            # hidden option to make the type and default work
+            "--generate-mode",
+            help=argparse.SUPPRESS,
+            type=_ConfigGenerateMode,
+            default=_ConfigGenerateMode.GENERATE_MISSING_AND_RUN,
+        )
+        generate_mode_exclusive.add_argument(
             "--generate-config",
-            action="store_true",
             help="Generate a config file, then exit.",
+            action="store_const",
+            const=_ConfigGenerateMode.GENERATE_EVERYTHING_AND_EXIT,
+            dest="generate_mode",
         )
-        generate_group.add_argument(
+        generate_mode_exclusive.add_argument(
             "--generate-missing-configs",
             "--generate-keys",
-            action="store_true",
             help="Generate any missing additional config files, then exit.",
+            action="store_const",
+            const=_ConfigGenerateMode.GENERATE_MISSING_AND_EXIT,
+            dest="generate_mode",
         )
+        generate_mode_exclusive.add_argument(
+            "--generate-missing-and-run",
+            help="Generate any missing additional config files, then run. This is the "
+            "default behaviour.",
+            action="store_const",
+            const=_ConfigGenerateMode.GENERATE_MISSING_AND_RUN,
+            dest="generate_mode",
+        )
+
+        generate_group = parser.add_argument_group("Details for --generate-config")
         generate_group.add_argument(
             "-H", "--server-name", help="The server name to generate a config file for."
         )
         generate_group.add_argument(
             "--report-stats",
             action="store",
-            help="Whether the generated config reports anonymized usage statistics.",
+            help="Whether the generated config reports homeserver usage statistics.",
             choices=["yes", "no"],
         )
         generate_group.add_argument(
@@ -670,11 +697,12 @@ class RootConfig:
         config_dir_path = os.path.abspath(config_dir_path)
         data_dir_path = os.getcwd()
 
-        generate_missing_configs = config_args.generate_missing_configs
-
         obj = cls(config_files)
 
-        if config_args.generate_config:
+        if (
+            config_args.generate_mode
+            == _ConfigGenerateMode.GENERATE_EVERYTHING_AND_EXIT
+        ):
             if config_args.report_stats is None:
                 parser.error(
                     "Please specify either --report-stats=yes or --report-stats=no\n\n"
@@ -732,11 +760,14 @@ class RootConfig:
                     )
                     % (config_path,)
                 )
-                generate_missing_configs = True
 
         config_dict = read_config_files(config_files)
-        if generate_missing_configs:
-            obj.generate_missing_files(config_dict, config_dir_path)
+        obj.generate_missing_files(config_dict, config_dir_path)
+
+        if config_args.generate_mode in (
+            _ConfigGenerateMode.GENERATE_EVERYTHING_AND_EXIT,
+            _ConfigGenerateMode.GENERATE_MISSING_AND_EXIT,
+        ):
             return None
 
         obj.parse_config_dict(
@@ -965,6 +996,12 @@ def read_file(file_path: Any, config_path: Iterable[str]) -> str:
         raise ConfigError("Error accessing file %r" % (file_path,), config_path) from e
 
 
+class _ConfigGenerateMode(Enum):
+    GENERATE_MISSING_AND_RUN = auto()
+    GENERATE_MISSING_AND_EXIT = auto()
+    GENERATE_EVERYTHING_AND_EXIT = auto()
+
+
 __all__ = [
     "Config",
     "RootConfig",
diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py
index d1335e77cd..b3972ede96 100644
--- a/synapse/config/account_validity.py
+++ b/synapse/config/account_validity.py
@@ -23,7 +23,7 @@ LEGACY_TEMPLATE_DIR_WARNING = """
 This server's configuration file is using the deprecated 'template_dir' setting in the
 'account_validity' section. Support for this setting has been deprecated and will be
 removed in a future version of Synapse. Server admins should instead use the new
-'custom_templates_directory' setting documented here:
+'custom_template_directory' setting documented here:
 https://matrix-org.github.io/synapse/latest/templates.html
 ---------------------------------------------------------------------------------------"""
 
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 6e11fbdb9a..a3af35b7c4 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -18,7 +18,6 @@
 import email.utils
 import logging
 import os
-from enum import Enum
 from typing import Any
 
 import attr
@@ -53,7 +52,7 @@ LEGACY_TEMPLATE_DIR_WARNING = """
 This server's configuration file is using the deprecated 'template_dir' setting in the
 'email' section. Support for this setting has been deprecated and will be removed in a
 future version of Synapse. Server admins should instead use the new
-'custom_templates_directory' setting documented here:
+'custom_template_directory' setting documented here:
 https://matrix-org.github.io/synapse/latest/templates.html
 ---------------------------------------------------------------------------------------"""
 
@@ -86,14 +85,19 @@ class EmailConfig(Config):
         if email_config is None:
             email_config = {}
 
+        self.force_tls = email_config.get("force_tls", False)
         self.email_smtp_host = email_config.get("smtp_host", "localhost")
-        self.email_smtp_port = email_config.get("smtp_port", 25)
+        self.email_smtp_port = email_config.get(
+            "smtp_port", 465 if self.force_tls else 25
+        )
         self.email_smtp_user = email_config.get("smtp_user", None)
         self.email_smtp_pass = email_config.get("smtp_pass", None)
         self.require_transport_security = email_config.get(
             "require_transport_security", False
         )
         self.enable_smtp_tls = email_config.get("enable_tls", True)
+        if self.force_tls and not self.enable_smtp_tls:
+            raise ConfigError("email.force_tls requires email.enable_tls to be true")
         if self.require_transport_security and not self.enable_smtp_tls:
             raise ConfigError(
                 "email.require_transport_security requires email.enable_tls to be true"
@@ -131,41 +135,22 @@ class EmailConfig(Config):
 
         self.email_enable_notifs = email_config.get("enable_notifs", False)
 
-        self.threepid_behaviour_email = (
-            # Have Synapse handle the email sending if account_threepid_delegates.email
-            # is not defined
-            # msisdn is currently always remote while Synapse does not support any method of
-            # sending SMS messages
-            ThreepidBehaviour.REMOTE
-            if self.root.registration.account_threepid_delegate_email
-            else ThreepidBehaviour.LOCAL
-        )
-
         if config.get("trust_identity_server_for_password_resets"):
             raise ConfigError(
                 'The config option "trust_identity_server_for_password_resets" '
-                'has been replaced by "account_threepid_delegate". '
-                "Please consult the configuration manual at docs/usage/configuration/config_documentation.md for "
-                "details and update your config file."
+                "is no longer supported. Please remove it from the config file."
             )
 
-        self.local_threepid_handling_disabled_due_to_email_config = False
-        if (
-            self.threepid_behaviour_email == ThreepidBehaviour.LOCAL
-            and email_config == {}
-        ):
-            # We cannot warn the user this has happened here
-            # Instead do so when a user attempts to reset their password
-            self.local_threepid_handling_disabled_due_to_email_config = True
-
-            self.threepid_behaviour_email = ThreepidBehaviour.OFF
+        # If we have email config settings, assume that we can verify ownership of
+        # email addresses.
+        self.can_verify_email = email_config != {}
 
         # Get lifetime of a validation token in milliseconds
         self.email_validation_token_lifetime = self.parse_duration(
             email_config.get("validation_token_lifetime", "1h")
         )
 
-        if self.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+        if self.can_verify_email:
             missing = []
             if not self.email_notif_from:
                 missing.append("email.notif_from")
@@ -356,18 +341,3 @@ class EmailConfig(Config):
                     "Config option email.invite_client_location must be a http or https URL",
                     path=("email", "invite_client_location"),
                 )
-
-
-class ThreepidBehaviour(Enum):
-    """
-    Enum to define the behaviour of Synapse with regards to when it contacts an identity
-    server for 3pid registration and password resets
-
-    REMOTE = use an external server to send tokens
-    LOCAL = send tokens ourselves
-    OFF = disable registration via 3pid and password resets
-    """
-
-    REMOTE = "remote"
-    LOCAL = "local"
-    OFF = "off"
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index ee443cea00..702b81e636 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -32,9 +32,6 @@ class ExperimentalConfig(Config):
         # MSC2716 (importing historical messages)
         self.msc2716_enabled: bool = experimental.get("msc2716_enabled", False)
 
-        # MSC2285 (private read receipts)
-        self.msc2285_enabled: bool = experimental.get("msc2285_enabled", False)
-
         # MSC3244 (room version capabilities)
         self.msc3244_enabled: bool = experimental.get("msc3244_enabled", True)
 
@@ -74,6 +71,9 @@ class ExperimentalConfig(Config):
         self.msc3720_enabled: bool = experimental.get("msc3720_enabled", False)
 
         # MSC2654: Unread counts
+        #
+        # Note that enabling this will result in an incorrect unread count for
+        # previously calculated push actions.
         self.msc2654_enabled: bool = experimental.get("msc2654_enabled", False)
 
         # MSC2815 (allow room moderators to view redacted event content)
@@ -88,5 +88,8 @@ class ExperimentalConfig(Config):
         # MSC3715: dir param on /relations.
         self.msc3715_enabled: bool = experimental.get("msc3715_enabled", False)
 
-        # MSC3827: Filtering of /publicRooms by room type
-        self.msc3827_enabled: bool = experimental.get("msc3827_enabled", False)
+        # MSC3848: Introduce errcodes for specific event sending failures
+        self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False)
+
+        # MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices.
+        self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False)
diff --git a/synapse/config/key.py b/synapse/config/key.py
index cc75efdf8f..f3dc4df695 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -217,7 +217,18 @@ class KeyConfig(Config):
 
         signing_keys = self.read_file(signing_key_path, name)
         try:
-            return read_signing_keys(signing_keys.splitlines(True))
+            loaded_signing_keys = read_signing_keys(
+                [
+                    signing_key_line
+                    for signing_key_line in signing_keys.splitlines(keepends=False)
+                    if signing_key_line.strip()
+                ]
+            )
+
+            if not loaded_signing_keys:
+                raise ConfigError(f"No signing keys in file {signing_key_path}")
+
+            return loaded_signing_keys
         except Exception as e:
             raise ConfigError("Error reading %s: %s" % (name, str(e)))
 
diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py
index 3b42be5b5b..f3134834e5 100644
--- a/synapse/config/metrics.py
+++ b/synapse/config/metrics.py
@@ -42,6 +42,35 @@ class MetricsConfig(Config):
 
     def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         self.enable_metrics = config.get("enable_metrics", False)
+
+        """
+        ### `enable_legacy_metrics` (experimental)
+
+        **Experimental: this option may be removed or have its behaviour
+        changed at any time, with no notice.**
+
+        Set to `true` to publish both legacy and non-legacy Prometheus metric names,
+        or to `false` to only publish non-legacy Prometheus metric names.
+        Defaults to `true`. Has no effect if `enable_metrics` is `false`.
+
+        Legacy metric names include:
+        - metrics containing colons in the name, such as `synapse_util_caches_response_cache:hits`, because colons are supposed to be reserved for user-defined recording rules;
+        - counters that don't end with the `_total` suffix, such as `synapse_federation_client_sent_edus`, therefore not adhering to the OpenMetrics standard.
+
+        These legacy metric names are unconventional and not compliant with OpenMetrics standards.
+        They are included for backwards compatibility.
+
+        Example configuration:
+        ```yaml
+        enable_legacy_metrics: false
+        ```
+
+        See https://github.com/matrix-org/synapse/issues/11106 for context.
+
+        *Since v1.67.0.*
+        """
+        self.enable_legacy_metrics = config.get("enable_legacy_metrics", True)
+
         self.report_stats = config.get("report_stats", None)
         self.report_stats_endpoint = config.get(
             "report_stats_endpoint", "https://matrix.org/report-usage-stats/push"
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 4fc1784efe..1ed001e105 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -21,7 +21,7 @@ from synapse.types import JsonDict
 from ._base import Config
 
 
-class RateLimitConfig:
+class RatelimitSettings:
     def __init__(
         self,
         config: Dict[str, float],
@@ -34,7 +34,7 @@ class RateLimitConfig:
 
 
 @attr.s(auto_attribs=True)
-class FederationRateLimitConfig:
+class FederationRatelimitSettings:
     window_size: int = 1000
     sleep_limit: int = 10
     sleep_delay: int = 500
@@ -50,11 +50,11 @@ class RatelimitConfig(Config):
         # Load the new-style messages config if it exists. Otherwise fall back
         # to the old method.
         if "rc_message" in config:
-            self.rc_message = RateLimitConfig(
+            self.rc_message = RatelimitSettings(
                 config["rc_message"], defaults={"per_second": 0.2, "burst_count": 10.0}
             )
         else:
-            self.rc_message = RateLimitConfig(
+            self.rc_message = RatelimitSettings(
                 {
                     "per_second": config.get("rc_messages_per_second", 0.2),
                     "burst_count": config.get("rc_message_burst_count", 10.0),
@@ -64,9 +64,9 @@ class RatelimitConfig(Config):
         # Load the new-style federation config, if it exists. Otherwise, fall
         # back to the old method.
         if "rc_federation" in config:
-            self.rc_federation = FederationRateLimitConfig(**config["rc_federation"])
+            self.rc_federation = FederationRatelimitSettings(**config["rc_federation"])
         else:
-            self.rc_federation = FederationRateLimitConfig(
+            self.rc_federation = FederationRatelimitSettings(
                 **{
                     k: v
                     for k, v in {
@@ -80,17 +80,17 @@ class RatelimitConfig(Config):
                 }
             )
 
-        self.rc_registration = RateLimitConfig(config.get("rc_registration", {}))
+        self.rc_registration = RatelimitSettings(config.get("rc_registration", {}))
 
-        self.rc_registration_token_validity = RateLimitConfig(
+        self.rc_registration_token_validity = RatelimitSettings(
             config.get("rc_registration_token_validity", {}),
             defaults={"per_second": 0.1, "burst_count": 5},
         )
 
         rc_login_config = config.get("rc_login", {})
-        self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {}))
-        self.rc_login_account = RateLimitConfig(rc_login_config.get("account", {}))
-        self.rc_login_failed_attempts = RateLimitConfig(
+        self.rc_login_address = RatelimitSettings(rc_login_config.get("address", {}))
+        self.rc_login_account = RatelimitSettings(rc_login_config.get("account", {}))
+        self.rc_login_failed_attempts = RatelimitSettings(
             rc_login_config.get("failed_attempts", {})
         )
 
@@ -101,47 +101,54 @@ class RatelimitConfig(Config):
         rc_admin_redaction = config.get("rc_admin_redaction")
         self.rc_admin_redaction = None
         if rc_admin_redaction:
-            self.rc_admin_redaction = RateLimitConfig(rc_admin_redaction)
+            self.rc_admin_redaction = RatelimitSettings(rc_admin_redaction)
 
-        self.rc_joins_local = RateLimitConfig(
+        self.rc_joins_local = RatelimitSettings(
             config.get("rc_joins", {}).get("local", {}),
             defaults={"per_second": 0.1, "burst_count": 10},
         )
-        self.rc_joins_remote = RateLimitConfig(
+        self.rc_joins_remote = RatelimitSettings(
             config.get("rc_joins", {}).get("remote", {}),
             defaults={"per_second": 0.01, "burst_count": 10},
         )
 
+        # Track the rate of joins to a given room. If there are too many, temporarily
+        # prevent local joins and remote joins via this server.
+        self.rc_joins_per_room = RatelimitSettings(
+            config.get("rc_joins_per_room", {}),
+            defaults={"per_second": 1, "burst_count": 10},
+        )
+
         # Ratelimit cross-user key requests:
         # * For local requests this is keyed by the sending device.
         # * For requests received over federation this is keyed by the origin.
         #
         # Note that this isn't exposed in the configuration as it is obscure.
-        self.rc_key_requests = RateLimitConfig(
+        self.rc_key_requests = RatelimitSettings(
             config.get("rc_key_requests", {}),
             defaults={"per_second": 20, "burst_count": 100},
         )
 
-        self.rc_3pid_validation = RateLimitConfig(
+        self.rc_3pid_validation = RatelimitSettings(
             config.get("rc_3pid_validation") or {},
             defaults={"per_second": 0.003, "burst_count": 5},
         )
 
-        self.rc_invites_per_room = RateLimitConfig(
+        self.rc_invites_per_room = RatelimitSettings(
             config.get("rc_invites", {}).get("per_room", {}),
             defaults={"per_second": 0.3, "burst_count": 10},
         )
-        self.rc_invites_per_user = RateLimitConfig(
+        self.rc_invites_per_user = RatelimitSettings(
             config.get("rc_invites", {}).get("per_user", {}),
             defaults={"per_second": 0.003, "burst_count": 5},
         )
 
-        self.rc_invites_per_issuer = RateLimitConfig(
+        self.rc_invites_per_issuer = RatelimitSettings(
             config.get("rc_invites", {}).get("per_issuer", {}),
             defaults={"per_second": 0.3, "burst_count": 10},
         )
 
-        self.rc_third_party_invite = RateLimitConfig(
+        self.rc_third_party_invite = RatelimitSettings(
             config.get("rc_third_party_invite", {}),
             defaults={
                 "per_second": self.rc_message.per_second,
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index fcf99be092..df1d83dfaa 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -13,13 +13,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
-from typing import Any, Optional
+from typing import Any, Dict, Optional
 
 from synapse.api.constants import RoomCreationPreset
-from synapse.config._base import Config, ConfigError
+from synapse.config._base import Config, ConfigError, read_file
 from synapse.types import JsonDict, RoomAlias, UserID
 from synapse.util.stringutils import random_string_with_symbols, strtobool
 
+NO_EMAIL_DELEGATE_ERROR = """\
+Delegation of email verification to an identity server is no longer supported. To
+continue to allow users to add email addresses to their accounts, and use them for
+password resets, configure Synapse with an SMTP server via the `email` setting, and
+remove `account_threepid_delegates.email`.
+"""
+
+CONFLICTING_SHARED_SECRET_OPTS_ERROR = """\
+You have configured both `registration_shared_secret` and
+`registration_shared_secret_path`. These are mutually incompatible.
+"""
+
 
 class RegistrationConfig(Config):
     section = "registration"
@@ -46,12 +58,22 @@ class RegistrationConfig(Config):
         self.enable_registration_token_3pid_bypass = config.get(
             "enable_registration_token_3pid_bypass", False
         )
+
+        # read the shared secret, either inline or from an external file
         self.registration_shared_secret = config.get("registration_shared_secret")
+        registration_shared_secret_path = config.get("registration_shared_secret_path")
+        if registration_shared_secret_path:
+            if self.registration_shared_secret:
+                raise ConfigError(CONFLICTING_SHARED_SECRET_OPTS_ERROR)
+            self.registration_shared_secret = read_file(
+                registration_shared_secret_path, ("registration_shared_secret_path",)
+            ).strip()
 
         self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
 
         account_threepid_delegates = config.get("account_threepid_delegates") or {}
-        self.account_threepid_delegate_email = account_threepid_delegates.get("email")
+        if "email" in account_threepid_delegates:
+            raise ConfigError(NO_EMAIL_DELEGATE_ERROR)
         self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn")
         self.default_identity_server = config.get("default_identity_server")
         self.allow_guest_access = config.get("allow_guest_access", False)
@@ -210,6 +232,21 @@ class RegistrationConfig(Config):
         else:
             return ""
 
+    def generate_files(self, config: Dict[str, Any], config_dir_path: str) -> None:
+        # if 'registration_shared_secret_path' is specified, and the target file
+        # does not exist, generate it.
+        registration_shared_secret_path = config.get("registration_shared_secret_path")
+        if registration_shared_secret_path and not self.path_exists(
+            registration_shared_secret_path
+        ):
+            print(
+                "Generating registration shared secret file "
+                + registration_shared_secret_path
+            )
+            secret = random_string_with_symbols(50)
+            with open(registration_shared_secret_path, "w") as f:
+                f.write(f"{secret}\n")
+
     @staticmethod
     def add_arguments(parser: argparse.ArgumentParser) -> None:
         reg_group = parser.add_argument_group("registration")
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 3c69dd325f..1033496bb4 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -42,6 +42,18 @@ THUMBNAIL_SIZE_YAML = """\
         #    method: %(method)s
 """
 
+# A map from the given media type to the type of thumbnail we should generate
+# for it.
+THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP = {
+    "image/jpeg": "jpeg",
+    "image/jpg": "jpeg",
+    "image/webp": "jpeg",
+    # Thumbnails can only be jpeg or png. We choose png thumbnails for gif
+    # because it can have transparency.
+    "image/gif": "png",
+    "image/png": "png",
+}
+
 HTTP_PROXY_SET_WARNING = """\
 The Synapse config url_preview_ip_range_blacklist will be ignored as an HTTP(s) proxy is configured."""
 
@@ -79,13 +91,22 @@ def parse_thumbnail_requirements(
         width = size["width"]
         height = size["height"]
         method = size["method"]
-        jpeg_thumbnail = ThumbnailRequirement(width, height, method, "image/jpeg")
-        png_thumbnail = ThumbnailRequirement(width, height, method, "image/png")
-        requirements.setdefault("image/jpeg", []).append(jpeg_thumbnail)
-        requirements.setdefault("image/jpg", []).append(jpeg_thumbnail)
-        requirements.setdefault("image/webp", []).append(jpeg_thumbnail)
-        requirements.setdefault("image/gif", []).append(png_thumbnail)
-        requirements.setdefault("image/png", []).append(png_thumbnail)
+
+        for format, thumbnail_format in THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP.items():
+            requirement = requirements.setdefault(format, [])
+            if thumbnail_format == "jpeg":
+                requirement.append(
+                    ThumbnailRequirement(width, height, method, "image/jpeg")
+                )
+            elif thumbnail_format == "png":
+                requirement.append(
+                    ThumbnailRequirement(width, height, method, "image/png")
+                )
+            else:
+                raise Exception(
+                    "Unknown thumbnail mapping from %s to %s. This is a Synapse problem, please report!"
+                    % (format, thumbnail_format)
+                )
     return {
         media_type: tuple(thumbnails) for media_type, thumbnails in requirements.items()
     }
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 085fe22c51..c91df636d9 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -36,6 +36,12 @@ from ._util import validate_config
 
 logger = logging.Logger(__name__)
 
+DIRECT_TCP_ERROR = """
+Using direct TCP replication for workers is no longer supported.
+
+Please see https://matrix-org.github.io/synapse/latest/upgrade.html#direct-tcp-replication-is-no-longer-supported-migrate-to-redis
+"""
+
 # by default, we attempt to listen on both '::' *and* '0.0.0.0' because some OSes
 # (Windows, macOS, other BSD/Linux where net.ipv6.bindv6only is set) will only listen
 # on IPv6 when '::' is set.
@@ -165,7 +171,6 @@ KNOWN_LISTENER_TYPES = {
     "http",
     "metrics",
     "manhole",
-    "replication",
 }
 
 KNOWN_RESOURCES = {
@@ -515,7 +520,9 @@ class ServerConfig(Config):
         ):
             raise ConfigError("allowed_avatar_mimetypes must be a list")
 
-        self.listeners = [parse_listener_def(x) for x in config.get("listeners", [])]
+        self.listeners = [
+            parse_listener_def(i, x) for i, x in enumerate(config.get("listeners", []))
+        ]
 
         # no_tls is not really supported any more, but let's grandfather it in
         # here.
@@ -880,9 +887,12 @@ def read_gc_thresholds(
         )
 
 
-def parse_listener_def(listener: Any) -> ListenerConfig:
+def parse_listener_def(num: int, listener: Any) -> ListenerConfig:
     """parse a listener config from the config file"""
     listener_type = listener["type"]
+    # Raise a helpful error if direct TCP replication is still configured.
+    if listener_type == "replication":
+        raise ConfigError(DIRECT_TCP_ERROR, ("listeners", str(num), "type"))
 
     port = listener.get("port")
     if not isinstance(port, int):
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index 2178cbf983..a452cc3a49 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -26,7 +26,7 @@ LEGACY_TEMPLATE_DIR_WARNING = """
 This server's configuration file is using the deprecated 'template_dir' setting in the
 'sso' section. Support for this setting has been deprecated and will be removed in a
 future version of Synapse. Server admins should instead use the new
-'custom_templates_directory' setting documented here:
+'custom_template_directory' setting documented here:
 https://matrix-org.github.io/synapse/latest/templates.html
 ---------------------------------------------------------------------------------------"""
 
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index f2716422b5..0fb725dd8f 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -27,7 +27,7 @@ from ._base import (
     RoutableShardedWorkerHandlingConfig,
     ShardedWorkerHandlingConfig,
 )
-from .server import ListenerConfig, parse_listener_def
+from .server import DIRECT_TCP_ERROR, ListenerConfig, parse_listener_def
 
 _FEDERATION_SENDER_WITH_SEND_FEDERATION_ENABLED_ERROR = """
 The send_federation config option must be disabled in the main
@@ -128,7 +128,8 @@ class WorkerConfig(Config):
             self.worker_app = None
 
         self.worker_listeners = [
-            parse_listener_def(x) for x in config.get("worker_listeners", [])
+            parse_listener_def(i, x)
+            for i, x in enumerate(config.get("worker_listeners", []))
         ]
         self.worker_daemonize = bool(config.get("worker_daemonize"))
         self.worker_pid_file = config.get("worker_pid_file")
@@ -142,7 +143,8 @@ class WorkerConfig(Config):
         self.worker_replication_host = config.get("worker_replication_host", None)
 
         # The port on the main synapse for TCP replication
-        self.worker_replication_port = config.get("worker_replication_port", None)
+        if "worker_replication_port" in config:
+            raise ConfigError(DIRECT_TCP_ERROR, ("worker_replication_port",))
 
         # The port on the main synapse for HTTP replication endpoint
         self.worker_replication_http_port = config.get("worker_replication_http_port")
diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py
index 7520647d1e..23b799ac32 100644
--- a/synapse/crypto/event_signing.py
+++ b/synapse/crypto/event_signing.py
@@ -28,6 +28,7 @@ from synapse.api.errors import Codes, SynapseError
 from synapse.api.room_versions import RoomVersion
 from synapse.events import EventBase
 from synapse.events.utils import prune_event, prune_event_dict
+from synapse.logging.opentracing import trace
 from synapse.types import JsonDict
 
 logger = logging.getLogger(__name__)
@@ -35,6 +36,7 @@ logger = logging.getLogger(__name__)
 Hasher = Callable[[bytes], "hashlib._Hash"]
 
 
+@trace
 def check_event_content_hash(
     event: EventBase, hash_algorithm: Hasher = hashlib.sha256
 ) -> bool:
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 0fc2c4b27e..c7d5ef92fc 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -30,7 +30,13 @@ from synapse.api.constants import (
     JoinRules,
     Membership,
 )
-from synapse.api.errors import AuthError, EventSizeError, SynapseError
+from synapse.api.errors import (
+    AuthError,
+    Codes,
+    EventSizeError,
+    SynapseError,
+    UnstableSpecAuthError,
+)
 from synapse.api.room_versions import (
     KNOWN_ROOM_VERSIONS,
     EventFormatVersions,
@@ -103,7 +109,7 @@ def validate_event_for_room_version(event: "EventBase") -> None:
         if not is_invite_via_3pid:
             raise AuthError(403, "Event not signed by sender's server")
 
-    if event.format_version in (EventFormatVersions.V1,):
+    if event.format_version in (EventFormatVersions.ROOM_V1_V2,):
         # Only older room versions have event IDs to check.
         event_id_domain = get_domain_from_id(event.event_id)
 
@@ -291,7 +297,11 @@ def check_state_dependent_auth_rules(
         invite_level = get_named_level(auth_dict, "invite", 0)
 
         if user_level < invite_level:
-            raise AuthError(403, "You don't have permission to invite users")
+            raise UnstableSpecAuthError(
+                403,
+                "You don't have permission to invite users",
+                errcode=Codes.INSUFFICIENT_POWER,
+            )
         else:
             logger.debug("Allowing! %s", event)
             return
@@ -474,7 +484,11 @@ def _is_membership_change_allowed(
             return
 
         if not caller_in_room:  # caller isn't joined
-            raise AuthError(403, "%s not in room %s." % (event.user_id, event.room_id))
+            raise UnstableSpecAuthError(
+                403,
+                "%s not in room %s." % (event.user_id, event.room_id),
+                errcode=Codes.NOT_JOINED,
+            )
 
     if Membership.INVITE == membership:
         # TODO (erikj): We should probably handle this more intelligently
@@ -484,10 +498,18 @@ def _is_membership_change_allowed(
         if target_banned:
             raise AuthError(403, "%s is banned from the room" % (target_user_id,))
         elif target_in_room:  # the target is already in the room.
-            raise AuthError(403, "%s is already in the room." % target_user_id)
+            raise UnstableSpecAuthError(
+                403,
+                "%s is already in the room." % target_user_id,
+                errcode=Codes.ALREADY_JOINED,
+            )
         else:
             if user_level < invite_level:
-                raise AuthError(403, "You don't have permission to invite users")
+                raise UnstableSpecAuthError(
+                    403,
+                    "You don't have permission to invite users",
+                    errcode=Codes.INSUFFICIENT_POWER,
+                )
     elif Membership.JOIN == membership:
         # Joins are valid iff caller == target and:
         # * They are not banned.
@@ -549,15 +571,27 @@ def _is_membership_change_allowed(
     elif Membership.LEAVE == membership:
         # TODO (erikj): Implement kicks.
         if target_banned and user_level < ban_level:
-            raise AuthError(403, "You cannot unban user %s." % (target_user_id,))
+            raise UnstableSpecAuthError(
+                403,
+                "You cannot unban user %s." % (target_user_id,),
+                errcode=Codes.INSUFFICIENT_POWER,
+            )
         elif target_user_id != event.user_id:
             kick_level = get_named_level(auth_events, "kick", 50)
 
             if user_level < kick_level or user_level <= target_level:
-                raise AuthError(403, "You cannot kick user %s." % target_user_id)
+                raise UnstableSpecAuthError(
+                    403,
+                    "You cannot kick user %s." % target_user_id,
+                    errcode=Codes.INSUFFICIENT_POWER,
+                )
     elif Membership.BAN == membership:
         if user_level < ban_level or user_level <= target_level:
-            raise AuthError(403, "You don't have permission to ban")
+            raise UnstableSpecAuthError(
+                403,
+                "You don't have permission to ban",
+                errcode=Codes.INSUFFICIENT_POWER,
+            )
     elif room_version.msc2403_knocking and Membership.KNOCK == membership:
         if join_rule != JoinRules.KNOCK and (
             not room_version.msc3787_knock_restricted_join_rule
@@ -567,7 +601,11 @@ def _is_membership_change_allowed(
         elif target_user_id != event.user_id:
             raise AuthError(403, "You cannot knock for other users")
         elif target_in_room:
-            raise AuthError(403, "You cannot knock on a room you are already in")
+            raise UnstableSpecAuthError(
+                403,
+                "You cannot knock on a room you are already in",
+                errcode=Codes.ALREADY_JOINED,
+            )
         elif caller_invited:
             raise AuthError(403, "You are already invited to this room")
         elif target_banned:
@@ -638,10 +676,11 @@ def _can_send_event(event: "EventBase", auth_events: StateMap["EventBase"]) -> b
     user_level = get_user_power_level(event.user_id, auth_events)
 
     if user_level < send_level:
-        raise AuthError(
+        raise UnstableSpecAuthError(
             403,
             "You don't have permission to post that to the room. "
             + "user_level (%d) < send_level (%d)" % (user_level, send_level),
+            errcode=Codes.INSUFFICIENT_POWER,
         )
 
     # Check state_key
@@ -677,7 +716,7 @@ def check_redaction(
     if user_level >= redact_level:
         return False
 
-    if room_version_obj.event_format == EventFormatVersions.V1:
+    if room_version_obj.event_format == EventFormatVersions.ROOM_V1_V2:
         redacter_domain = get_domain_from_id(event.event_id)
         if not isinstance(event.redacts, str):
             return False
@@ -716,9 +755,10 @@ def check_historical(
     historical_level = get_named_level(auth_events, "historical", 100)
 
     if user_level < historical_level:
-        raise AuthError(
+        raise UnstableSpecAuthError(
             403,
             'You don\'t have permission to send send historical related events ("insertion", "batch", and "marker")',
+            errcode=Codes.INSUFFICIENT_POWER,
         )
 
 
@@ -740,6 +780,32 @@ def _check_power_levels(
         except Exception:
             raise SynapseError(400, "Not a valid power level: %s" % (v,))
 
+    # Reject events with stringy power levels if required by room version
+    if (
+        event.type == EventTypes.PowerLevels
+        and room_version_obj.msc3667_int_only_power_levels
+    ):
+        for k, v in event.content.items():
+            if k in {
+                "users_default",
+                "events_default",
+                "state_default",
+                "ban",
+                "redact",
+                "kick",
+                "invite",
+            }:
+                if not isinstance(v, int):
+                    raise SynapseError(400, f"{v!r} must be an integer.")
+            if k in {"events", "notifications", "users"}:
+                if not isinstance(v, dict) or not all(
+                    isinstance(v, int) for v in v.values()
+                ):
+                    raise SynapseError(
+                        400,
+                        f"{v!r} must be a dict wherein all the values are integers.",
+                    )
+
     key = (event.type, event.state_key)
     current_state = auth_events.get(key)
 
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 39ad2793d9..b2c9119fd0 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -442,7 +442,7 @@ class EventBase(metaclass=abc.ABCMeta):
 
 
 class FrozenEvent(EventBase):
-    format_version = EventFormatVersions.V1  # All events of this type are V1
+    format_version = EventFormatVersions.ROOM_V1_V2  # All events of this type are V1
 
     def __init__(
         self,
@@ -490,7 +490,7 @@ class FrozenEvent(EventBase):
 
 
 class FrozenEventV2(EventBase):
-    format_version = EventFormatVersions.V2  # All events of this type are V2
+    format_version = EventFormatVersions.ROOM_V3  # All events of this type are V2
 
     def __init__(
         self,
@@ -567,7 +567,7 @@ class FrozenEventV2(EventBase):
 class FrozenEventV3(FrozenEventV2):
     """FrozenEventV3, which differs from FrozenEventV2 only in the event_id format"""
 
-    format_version = EventFormatVersions.V3  # All events of this type are V3
+    format_version = EventFormatVersions.ROOM_V4_PLUS  # All events of this type are V3
 
     @property
     def event_id(self) -> str:
@@ -597,11 +597,11 @@ def _event_type_from_format_version(
         `FrozenEvent`
     """
 
-    if format_version == EventFormatVersions.V1:
+    if format_version == EventFormatVersions.ROOM_V1_V2:
         return FrozenEvent
-    elif format_version == EventFormatVersions.V2:
+    elif format_version == EventFormatVersions.ROOM_V3:
         return FrozenEventV2
-    elif format_version == EventFormatVersions.V3:
+    elif format_version == EventFormatVersions.ROOM_V4_PLUS:
         return FrozenEventV3
     else:
         raise Exception("No event format %r" % (format_version,))
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 98c203ada0..746bd3978d 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -24,9 +24,11 @@ from synapse.api.room_versions import (
     RoomVersion,
 )
 from synapse.crypto.event_signing import add_hashes_and_signatures
+from synapse.event_auth import auth_types_for_event
 from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
 from synapse.state import StateHandler
 from synapse.storage.databases.main import DataStore
+from synapse.storage.state import StateFilter
 from synapse.types import EventID, JsonDict
 from synapse.util import Clock
 from synapse.util.stringutils import random_string
@@ -120,8 +122,12 @@ class EventBuilder:
             The signed and hashed event.
         """
         if auth_event_ids is None:
-            state_ids = await self._state.get_current_state_ids(
-                self.room_id, prev_event_ids
+            state_ids = await self._state.compute_state_after_events(
+                self.room_id,
+                prev_event_ids,
+                state_filter=StateFilter.from_types(
+                    auth_types_for_event(self.room_version, self)
+                ),
             )
             auth_event_ids = self._event_auth_handler.compute_auth_events(
                 self, state_ids
@@ -131,7 +137,7 @@ class EventBuilder:
         # The types of auth/prev events changes between event versions.
         prev_events: Union[List[str], List[Tuple[str, Dict[str, str]]]]
         auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]]
-        if format_version == EventFormatVersions.V1:
+        if format_version == EventFormatVersions.ROOM_V1_V2:
             auth_events = await self._store.add_event_hashes(auth_event_ids)
             prev_events = await self._store.add_event_hashes(prev_event_ids)
         else:
@@ -247,7 +253,7 @@ def create_local_event_from_event_dict(
 
     time_now = int(clock.time_msec())
 
-    if format_version == EventFormatVersions.V1:
+    if format_version == EventFormatVersions.ROOM_V1_V2:
         event_dict["event_id"] = _create_event_id(clock, hostname)
 
     event_dict["origin"] = hostname
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index b700cbbfa1..d3c8083e4a 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -11,11 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 import attr
 from frozendict import frozendict
-from typing_extensions import Literal
 
 from synapse.appservice import ApplicationService
 from synapse.events import EventBase
@@ -33,7 +32,7 @@ class EventContext:
     Holds information relevant to persisting an event
 
     Attributes:
-        rejected: A rejection reason if the event was rejected, else False
+        rejected: A rejection reason if the event was rejected, else None
 
         _state_group: The ID of the state group for this event. Note that state events
             are persisted with a state group which includes the new event, so this is
@@ -85,7 +84,7 @@ class EventContext:
     """
 
     _storage: "StorageControllers"
-    rejected: Union[Literal[False], str] = False
+    rejected: Optional[str] = None
     _state_group: Optional[int] = None
     state_group_before_event: Optional[int] = None
     _state_delta_due_to_event: Optional[StateMap[str]] = None
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 4a3bfb38f1..623a2c71ea 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -32,6 +32,7 @@ from typing_extensions import Literal
 
 import synapse
 from synapse.api.errors import Codes
+from synapse.logging.opentracing import trace
 from synapse.rest.media.v1._base import FileInfo
 from synapse.rest.media.v1.media_storage import ReadableFileWrapper
 from synapse.spam_checker_api import RegistrationBehaviour
@@ -378,6 +379,7 @@ class SpamChecker:
         if check_media_file_for_spam is not None:
             self._check_media_file_for_spam_callbacks.append(check_media_file_for_spam)
 
+    @trace
     async def check_event_for_spam(
         self, event: "synapse.events.EventBase"
     ) -> Union[Tuple[Codes, JsonDict], str]:
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index ac91c5eb57..71853caad8 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -161,7 +161,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDic
     elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_BATCH:
         add_fields(EventContentFields.MSC2716_BATCH_ID)
     elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_MARKER:
-        add_fields(EventContentFields.MSC2716_MARKER_INSERTION)
+        add_fields(EventContentFields.MSC2716_INSERTION_EVENT_REFERENCE)
 
     allowed_fields = {k: v for k, v in event_dict.items() if k in allowed_keys}
 
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 27c8beba25..a6f0104396 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -45,7 +45,7 @@ class EventValidator:
         """
         self.validate_builder(event)
 
-        if event.format_version == EventFormatVersions.V1:
+        if event.format_version == EventFormatVersions.ROOM_V1_V2:
             EventID.from_string(event.event_id)
 
         required = [
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 2522bf78fc..abe2c1971a 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -23,6 +23,7 @@ from synapse.crypto.keyring import Keyring
 from synapse.events import EventBase, make_event_from_dict
 from synapse.events.utils import prune_event, validate_canonicaljson
 from synapse.http.servlet import assert_params_in_dict
+from synapse.logging.opentracing import log_kv, trace
 from synapse.types import JsonDict, get_domain_from_id
 
 if TYPE_CHECKING:
@@ -55,6 +56,7 @@ class FederationBase:
         self._clock = hs.get_clock()
         self._storage_controllers = hs.get_storage_controllers()
 
+    @trace
     async def _check_sigs_and_hash(
         self, room_version: RoomVersion, pdu: EventBase
     ) -> EventBase:
@@ -97,17 +99,36 @@ class FederationBase:
                     "Event %s seems to have been redacted; using our redacted copy",
                     pdu.event_id,
                 )
+                log_kv(
+                    {
+                        "message": "Event seems to have been redacted; using our redacted copy",
+                        "event_id": pdu.event_id,
+                    }
+                )
             else:
                 logger.warning(
                     "Event %s content has been tampered, redacting",
                     pdu.event_id,
                 )
+                log_kv(
+                    {
+                        "message": "Event content has been tampered, redacting",
+                        "event_id": pdu.event_id,
+                    }
+                )
             return redacted_event
 
         spam_check = await self.spam_checker.check_event_for_spam(pdu)
 
         if spam_check != self.spam_checker.NOT_SPAM:
             logger.warning("Event contains spam, soft-failing %s", pdu.event_id)
+            log_kv(
+                {
+                    "message": "Event contains spam, redacting (to save disk space) "
+                    "as well as soft-failing (to stop using the event in prev_events)",
+                    "event_id": pdu.event_id,
+                }
+            )
             # we redact (to save disk space) as well as soft-failing (to stop
             # using the event in prev_events).
             redacted_event = prune_event(pdu)
@@ -117,6 +138,7 @@ class FederationBase:
         return pdu
 
 
+@trace
 async def _check_sigs_on_pdu(
     keyring: Keyring, room_version: RoomVersion, pdu: EventBase
 ) -> None:
@@ -172,7 +194,7 @@ async def _check_sigs_on_pdu(
     # event id's domain (normally only the case for joins/leaves), and add additional
     # checks. Only do this if the room version has a concept of event ID domain
     # (ie, the room version uses old-style non-hash event IDs).
-    if room_version.event_format == EventFormatVersions.V1:
+    if room_version.event_format == EventFormatVersions.ROOM_V1_V2:
         event_domain = get_domain_from_id(pdu.event_id)
         if event_domain != sender_domain:
             try:
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 66e6305562..4a4289ee7c 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -53,7 +53,7 @@ from synapse.api.room_versions import (
     RoomVersion,
     RoomVersions,
 )
-from synapse.events import EventBase, builder
+from synapse.events import EventBase, builder, make_event_from_dict
 from synapse.federation.federation_base import (
     FederationBase,
     InvalidEventSignatureError,
@@ -61,6 +61,7 @@ from synapse.federation.federation_base import (
 )
 from synapse.federation.transport.client import SendJoinResponse
 from synapse.http.types import QueryParams
+from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, tag_args, trace
 from synapse.types import JsonDict, UserID, get_domain_from_id
 from synapse.util.async_helpers import concurrently_execute
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -217,7 +218,7 @@ class FederationClient(FederationBase):
         )
 
     async def claim_client_keys(
-        self, destination: str, content: JsonDict, timeout: int
+        self, destination: str, content: JsonDict, timeout: Optional[int]
     ) -> JsonDict:
         """Claims one-time keys for a device hosted on a remote server.
 
@@ -233,6 +234,8 @@ class FederationClient(FederationBase):
             destination, content, timeout
         )
 
+    @trace
+    @tag_args
     async def backfill(
         self, dest: str, room_id: str, limit: int, extremities: Collection[str]
     ) -> Optional[List[EventBase]]:
@@ -299,7 +302,8 @@ class FederationClient(FederationBase):
                 moving to the next destination. None indicates no timeout.
 
         Returns:
-            The requested PDU, or None if we were unable to find it.
+            A copy of the requested PDU that is safe to modify, or None if we
+            were unable to find it.
 
         Raises:
             SynapseError, NotRetryingDestination, FederationDeniedError
@@ -309,7 +313,7 @@ class FederationClient(FederationBase):
         )
 
         logger.debug(
-            "retrieved event id %s from %s: %r",
+            "get_pdu_from_destination_raw: retrieved event id %s from %s: %r",
             event_id,
             destination,
             transaction_data,
@@ -334,6 +338,8 @@ class FederationClient(FederationBase):
 
         return None
 
+    @trace
+    @tag_args
     async def get_pdu(
         self,
         destinations: Iterable[str],
@@ -358,55 +364,95 @@ class FederationClient(FederationBase):
             The requested PDU, or None if we were unable to find it.
         """
 
+        logger.debug(
+            "get_pdu: event_id=%s from destinations=%s", event_id, destinations
+        )
+
         # TODO: Rate limit the number of times we try and get the same event.
 
-        ev = self._get_pdu_cache.get(event_id)
-        if ev:
-            return ev
+        # We might need the same event multiple times in quick succession (before
+        # it gets persisted to the database), so we cache the results of the lookup.
+        # Note that this is separate to the regular get_event cache which caches
+        # events once they have been persisted.
+        event = self._get_pdu_cache.get(event_id)
+
+        # If we don't see the event in the cache, go try to fetch it from the
+        # provided remote federated destinations
+        if not event:
+            pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
+
+            for destination in destinations:
+                now = self._clock.time_msec()
+                last_attempt = pdu_attempts.get(destination, 0)
+                if last_attempt + PDU_RETRY_TIME_MS > now:
+                    logger.debug(
+                        "get_pdu: skipping destination=%s because we tried it recently last_attempt=%s and we only check every %s (now=%s)",
+                        destination,
+                        last_attempt,
+                        PDU_RETRY_TIME_MS,
+                        now,
+                    )
+                    continue
+
+                try:
+                    event = await self.get_pdu_from_destination_raw(
+                        destination=destination,
+                        event_id=event_id,
+                        room_version=room_version,
+                        timeout=timeout,
+                    )
 
-        pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
+                    pdu_attempts[destination] = now
 
-        signed_pdu = None
-        for destination in destinations:
-            now = self._clock.time_msec()
-            last_attempt = pdu_attempts.get(destination, 0)
-            if last_attempt + PDU_RETRY_TIME_MS > now:
-                continue
+                    if event:
+                        # Prime the cache
+                        self._get_pdu_cache[event.event_id] = event
 
-            try:
-                signed_pdu = await self.get_pdu_from_destination_raw(
-                    destination=destination,
-                    event_id=event_id,
-                    room_version=room_version,
-                    timeout=timeout,
-                )
+                        # Now that we have an event, we can break out of this
+                        # loop and stop asking other destinations.
+                        break
 
-                pdu_attempts[destination] = now
-
-            except SynapseError as e:
-                logger.info(
-                    "Failed to get PDU %s from %s because %s", event_id, destination, e
-                )
-                continue
-            except NotRetryingDestination as e:
-                logger.info(str(e))
-                continue
-            except FederationDeniedError as e:
-                logger.info(str(e))
-                continue
-            except Exception as e:
-                pdu_attempts[destination] = now
+                except SynapseError as e:
+                    logger.info(
+                        "Failed to get PDU %s from %s because %s",
+                        event_id,
+                        destination,
+                        e,
+                    )
+                    continue
+                except NotRetryingDestination as e:
+                    logger.info(str(e))
+                    continue
+                except FederationDeniedError as e:
+                    logger.info(str(e))
+                    continue
+                except Exception as e:
+                    pdu_attempts[destination] = now
+
+                    logger.info(
+                        "Failed to get PDU %s from %s because %s",
+                        event_id,
+                        destination,
+                        e,
+                    )
+                    continue
 
-                logger.info(
-                    "Failed to get PDU %s from %s because %s", event_id, destination, e
-                )
-                continue
+        if not event:
+            return None
 
-        if signed_pdu:
-            self._get_pdu_cache[event_id] = signed_pdu
+        # `event` now refers to an object stored in `get_pdu_cache`. Our
+        # callers may need to modify the returned object (eg to set
+        # `event.internal_metadata.outlier = true`), so we return a copy
+        # rather than the original object.
+        event_copy = make_event_from_dict(
+            event.get_pdu_json(),
+            event.room_version,
+        )
 
-        return signed_pdu
+        return event_copy
 
+    @trace
+    @tag_args
     async def get_room_state_ids(
         self, destination: str, room_id: str, event_id: str
     ) -> Tuple[List[str], List[str]]:
@@ -426,6 +472,23 @@ class FederationClient(FederationBase):
         state_event_ids = result["pdu_ids"]
         auth_event_ids = result.get("auth_chain_ids", [])
 
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "state_event_ids",
+            str(state_event_ids),
+        )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "state_event_ids.length",
+            str(len(state_event_ids)),
+        )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "auth_event_ids",
+            str(auth_event_ids),
+        )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "auth_event_ids.length",
+            str(len(auth_event_ids)),
+        )
+
         if not isinstance(state_event_ids, list) or not isinstance(
             auth_event_ids, list
         ):
@@ -433,6 +496,8 @@ class FederationClient(FederationBase):
 
         return state_event_ids, auth_event_ids
 
+    @trace
+    @tag_args
     async def get_room_state(
         self,
         destination: str,
@@ -492,6 +557,7 @@ class FederationClient(FederationBase):
 
         return valid_state_events, valid_auth_events
 
+    @trace
     async def _check_sigs_and_hash_and_fetch(
         self,
         origin: str,
@@ -521,11 +587,15 @@ class FederationClient(FederationBase):
         Returns:
             A list of PDUs that have valid signatures and hashes.
         """
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "pdus.length",
+            str(len(pdus)),
+        )
 
         # We limit how many PDUs we check at once, as if we try to do hundreds
         # of thousands of PDUs at once we see large memory spikes.
 
-        valid_pdus = []
+        valid_pdus: List[EventBase] = []
 
         async def _execute(pdu: EventBase) -> None:
             valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
@@ -541,6 +611,8 @@ class FederationClient(FederationBase):
 
         return valid_pdus
 
+    @trace
+    @tag_args
     async def _check_sigs_and_hash_and_fetch_one(
         self,
         pdu: EventBase,
@@ -573,16 +645,27 @@ class FederationClient(FederationBase):
         except InvalidEventSignatureError as e:
             logger.warning(
                 "Signature on retrieved event %s was invalid (%s). "
-                "Checking local store/orgin server",
+                "Checking local store/origin server",
                 pdu.event_id,
                 e,
             )
+            log_kv(
+                {
+                    "message": "Signature on retrieved event was invalid. "
+                    "Checking local store/origin server",
+                    "event_id": pdu.event_id,
+                    "InvalidEventSignatureError": e,
+                }
+            )
 
         # Check local db.
         res = await self.store.get_event(
             pdu.event_id, allow_rejected=True, allow_none=True
         )
 
+        # If the PDU fails its signature check and we don't have it in our
+        # database, we then request it from sender's server (if that is not the
+        # same as `origin`).
         pdu_origin = get_domain_from_id(pdu.sender)
         if not res and pdu_origin != origin:
             try:
@@ -686,6 +769,12 @@ class FederationClient(FederationBase):
         if failover_errcodes is None:
             failover_errcodes = ()
 
+        if not destinations:
+            # Give a bit of a clearer message if no servers were specified at all.
+            raise SynapseError(
+                502, f"Failed to {description} via any server: No servers specified."
+            )
+
         for destination in destinations:
             if destination == self.server_name:
                 continue
@@ -735,7 +824,7 @@ class FederationClient(FederationBase):
                     "Failed to %s via %s", description, destination, exc_info=True
                 )
 
-        raise SynapseError(502, "Failed to %s via any server" % (description,))
+        raise SynapseError(502, f"Failed to {description} via any server")
 
     async def make_membership_event(
         self,
@@ -1101,7 +1190,7 @@ class FederationClient(FederationBase):
             # Otherwise, consider it a legitimate error and raise.
             err = e.to_synapse_error()
             if self._is_unknown_endpoint(e, err):
-                if room_version.event_format != EventFormatVersions.V1:
+                if room_version.event_format != EventFormatVersions.ROOM_V1_V2:
                     raise SynapseError(
                         400,
                         "User's homeserver does not support this room version",
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 5dfdc86740..3bf84cf625 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -61,7 +61,12 @@ from synapse.logging.context import (
     nested_logging_context,
     run_in_background,
 )
-from synapse.logging.opentracing import log_kv, start_active_span_from_edu, trace
+from synapse.logging.opentracing import (
+    log_kv,
+    start_active_span_from_edu,
+    tag_args,
+    trace,
+)
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.replication.http.federation import (
     ReplicationFederationSendEduRestServlet,
@@ -118,6 +123,7 @@ class FederationServer(FederationBase):
         self._federation_event_handler = hs.get_federation_event_handler()
         self.state = hs.get_state_handler()
         self._event_auth_handler = hs.get_event_auth_handler()
+        self._room_member_handler = hs.get_room_member_handler()
 
         self._state_storage_controller = hs.get_storage_controllers().state
 
@@ -468,7 +474,7 @@ class FederationServer(FederationBase):
                     )
                     for pdu in pdus_by_room[room_id]:
                         event_id = pdu.event_id
-                        pdu_results[event_id] = e.error_dict()
+                        pdu_results[event_id] = e.error_dict(self.hs.config)
                     return
 
                 for pdu in pdus_by_room[room_id]:
@@ -546,6 +552,8 @@ class FederationServer(FederationBase):
 
         return 200, resp
 
+    @trace
+    @tag_args
     async def on_state_ids_request(
         self, origin: str, room_id: str, event_id: str
     ) -> Tuple[int, JsonDict]:
@@ -568,6 +576,8 @@ class FederationServer(FederationBase):
 
         return 200, resp
 
+    @trace
+    @tag_args
     async def _on_state_ids_request_compute(
         self, room_id: str, event_id: str
     ) -> JsonDict:
@@ -621,6 +631,15 @@ class FederationServer(FederationBase):
             )
             raise IncompatibleRoomVersionError(room_version=room_version)
 
+        # Refuse the request if that room has seen too many joins recently.
+        # This is in addition to the HS-level rate limiting applied by
+        # BaseFederationServlet.
+        # type-ignore: mypy doesn't seem able to deduce the type of the limiter(!?)
+        await self._room_member_handler._join_rate_per_room_limiter.ratelimit(  # type: ignore[has-type]
+            requester=None,
+            key=room_id,
+            update=False,
+        )
         pdu = await self.handler.on_make_join_request(origin, room_id, user_id)
         return {"event": pdu.get_templated_pdu_json(), "room_version": room_version}
 
@@ -655,6 +674,12 @@ class FederationServer(FederationBase):
         room_id: str,
         caller_supports_partial_state: bool = False,
     ) -> Dict[str, Any]:
+        await self._room_member_handler._join_rate_per_room_limiter.ratelimit(  # type: ignore[has-type]
+            requester=None,
+            key=room_id,
+            update=False,
+        )
+
         event, context = await self._on_send_membership_event(
             origin, content, Membership.JOIN, room_id
         )
@@ -738,6 +763,17 @@ class FederationServer(FederationBase):
             The partial knock event.
         """
         origin_host, _ = parse_server_name(origin)
+
+        if await self.store.is_partial_state_room(room_id):
+            # Before we do anything: check if the room is partial-stated.
+            # Note that at the time this check was added, `on_make_knock_request` would
+            # block due to https://github.com/matrix-org/synapse/issues/12997.
+            raise SynapseError(
+                404,
+                "Unable to handle /make_knock right now; this server is not fully joined.",
+                errcode=Codes.NOT_FOUND,
+            )
+
         await self.check_server_matches_acl(origin_host, room_id)
 
         room_version = await self.store.get_room_version(room_id)
@@ -827,8 +863,25 @@ class FederationServer(FederationBase):
                 Codes.BAD_JSON,
             )
 
+        # Note that get_room_version throws if the room does not exist here.
         room_version = await self.store.get_room_version(room_id)
 
+        if await self.store.is_partial_state_room(room_id):
+            # If our server is still only partially joined, we can't give a complete
+            # response to /send_join, /send_knock or /send_leave.
+            # This is because we will not be able to provide the server list (for partial
+            # joins) or the full state (for full joins).
+            # Return a 404 as we would if we weren't in the room at all.
+            logger.info(
+                f"Rejecting /send_{membership_type} to %s because it's a partial state room",
+                room_id,
+            )
+            raise SynapseError(
+                404,
+                f"Unable to handle /send_{membership_type} right now; this server is not fully joined.",
+                errcode=Codes.NOT_FOUND,
+            )
+
         if membership_type == Membership.KNOCK and not room_version.msc2403_knocking:
             raise SynapseError(
                 403,
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 99a794c042..a6cb3ba58f 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -62,12 +62,12 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 sent_pdus_destination_dist_count = Counter(
-    "synapse_federation_client_sent_pdu_destinations:count",
+    "synapse_federation_client_sent_pdu_destinations_count",
     "Number of PDUs queued for sending to one or more destinations",
 )
 
 sent_pdus_destination_dist_total = Counter(
-    "synapse_federation_client_sent_pdu_destinations:total",
+    "synapse_federation_client_sent_pdu_destinations",
     "Total number of PDUs queued for sending across all destinations",
 )
 
@@ -351,7 +351,11 @@ class FederationSender(AbstractFederationSender):
             self._is_processing = True
             while True:
                 last_token = await self.store.get_federation_out_pos("events")
-                next_token, events = await self.store.get_all_new_events_stream(
+                (
+                    next_token,
+                    events,
+                    event_to_received_ts,
+                ) = await self.store.get_all_new_events_stream(
                     last_token, self._last_poked_id, limit=100
                 )
 
@@ -437,6 +441,19 @@ class FederationSender(AbstractFederationSender):
                             destinations = await self._external_cache.get(
                                 "get_joined_hosts", str(sg)
                             )
+                            if destinations is None:
+                                # Add logging to help track down #13444
+                                logger.info(
+                                    "Unexpectedly did not have cached destinations for %s / %s",
+                                    sg,
+                                    event.event_id,
+                                )
+                        else:
+                            # Add logging to help track down #13444
+                            logger.info(
+                                "Unexpectedly did not have cached prev group for %s",
+                                event.event_id,
+                            )
 
                     if destinations is None:
                         try:
@@ -476,7 +493,7 @@ class FederationSender(AbstractFederationSender):
                         await self._send_pdu(event, sharded_destinations)
 
                         now = self.clock.time_msec()
-                        ts = await self.store.get_received_ts(event.event_id)
+                        ts = event_to_received_ts[event.event_id]
                         assert ts is not None
                         synapse.metrics.event_processing_lag_by_event.labels(
                             "federation_sender"
@@ -509,7 +526,7 @@ class FederationSender(AbstractFederationSender):
 
                 if events:
                     now = self.clock.time_msec()
-                    ts = await self.store.get_received_ts(events[-1].event_id)
+                    ts = event_to_received_ts[events[-1].event_id]
                     assert ts is not None
 
                     synapse.metrics.event_processing_lag.labels(
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 9e84bd677e..32074b8ca6 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -619,7 +619,7 @@ class TransportLayerClient:
         )
 
     async def claim_client_keys(
-        self, destination: str, query_content: JsonDict, timeout: int
+        self, destination: str, query_content: JsonDict, timeout: Optional[int]
     ) -> JsonDict:
         """Claim one-time keys for a list of devices hosted on a remote server.
 
diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py
index 84100a5a52..1db8009d6c 100644
--- a/synapse/federation/transport/server/_base.py
+++ b/synapse/federation/transport/server/_base.py
@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tupl
 
 from synapse.api.errors import Codes, FederationDeniedError, SynapseError
 from synapse.api.urls import FEDERATION_V1_PREFIX
-from synapse.http.server import HttpServer, ServletCallback, is_method_cancellable
+from synapse.http.server import HttpServer, ServletCallback
 from synapse.http.servlet import parse_json_object_from_request
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import run_in_background
@@ -34,6 +34,7 @@ from synapse.logging.opentracing import (
     whitelisted_homeserver,
 )
 from synapse.types import JsonDict
+from synapse.util.cancellation import is_function_cancellable
 from synapse.util.ratelimitutils import FederationRateLimiter
 from synapse.util.stringutils import parse_and_validate_server_name
 
@@ -309,7 +310,7 @@ class BaseFederationServlet:
                 raise
 
             # update the active opentracing span with the authenticated entity
-            set_tag("authenticated_entity", origin)
+            set_tag("authenticated_entity", str(origin))
 
             # if the origin is authenticated and whitelisted, use its span context
             # as the parent.
@@ -375,7 +376,7 @@ class BaseFederationServlet:
             if code is None:
                 continue
 
-            if is_method_cancellable(code):
+            if is_function_cancellable(code):
                 # The wrapper added by `self._wrap` will inherit the cancellable flag,
                 # but the wrapper itself does not support cancellation yet.
                 # Once resolved, the cancellation tests in
diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
index f7884bfbe0..6bb4659c4c 100644
--- a/synapse/federation/transport/server/federation.py
+++ b/synapse/federation/transport/server/federation.py
@@ -549,8 +549,7 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet):
 
 
 class FederationGetMissingEventsServlet(BaseFederationServerServlet):
-    # TODO(paul): Why does this path alone end with "/?" optional?
-    PATH = "/get_missing_events/(?P<room_id>[^/]*)/?"
+    PATH = "/get_missing_events/(?P<room_id>[^/]*)"
 
     async def on_POST(
         self,
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index d4fe7df533..cf9f19608a 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -70,6 +70,7 @@ class AdminHandler:
             "appservice_id",
             "consent_server_notice_sent",
             "consent_version",
+            "consent_ts",
             "user_type",
             "is_guest",
         }
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 814553e098..203b62e015 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -104,14 +104,15 @@ class ApplicationServicesHandler:
         with Measure(self.clock, "notify_interested_services"):
             self.is_processing = True
             try:
-                limit = 100
                 upper_bound = -1
                 while upper_bound < self.current_max:
+                    last_token = await self.store.get_appservice_last_pos()
                     (
                         upper_bound,
                         events,
-                    ) = await self.store.get_new_events_for_appservice(
-                        self.current_max, limit
+                        event_to_received_ts,
+                    ) = await self.store.get_all_new_events_stream(
+                        last_token, self.current_max, limit=100, get_prev_content=True
                     )
 
                     events_by_room: Dict[str, List[EventBase]] = {}
@@ -150,7 +151,7 @@ class ApplicationServicesHandler:
                             )
 
                         now = self.clock.time_msec()
-                        ts = await self.store.get_received_ts(event.event_id)
+                        ts = event_to_received_ts[event.event_id]
                         assert ts is not None
 
                         synapse.metrics.event_processing_lag_by_event.labels(
@@ -187,7 +188,7 @@ class ApplicationServicesHandler:
 
                     if events:
                         now = self.clock.time_msec()
-                        ts = await self.store.get_received_ts(events[-1].event_id)
+                        ts = event_to_received_ts[events[-1].event_id]
                         assert ts is not None
 
                         synapse.metrics.event_processing_lag.labels(
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 3d83236b0c..0327fc57a4 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -280,7 +280,7 @@ class AuthHandler:
         that it isn't stolen by re-authenticating them.
 
         Args:
-            requester: The user, as given by the access token
+            requester: The user making the request, according to the access token.
 
             request: The request sent by the client.
 
@@ -565,7 +565,7 @@ class AuthHandler:
             except LoginError as e:
                 # this step failed. Merge the error dict into the response
                 # so that the client can have another go.
-                errordict = e.error_dict()
+                errordict = e.error_dict(self.hs.config)
 
         creds = await self.store.get_completed_ui_auth_stages(session.session_id)
         for f in flows:
@@ -1435,20 +1435,25 @@ class AuthHandler:
             access_token: access token to be deleted
 
         """
-        user_info = await self.auth.get_user_by_access_token(access_token)
+        token = await self.store.get_user_by_access_token(access_token)
+        if not token:
+            # At this point, the token should already have been fetched once by
+            # the caller, so this should not happen, unless of a race condition
+            # between two delete requests
+            raise SynapseError(HTTPStatus.UNAUTHORIZED, "Unrecognised access token")
         await self.store.delete_access_token(access_token)
 
         # see if any modules want to know about this
         await self.password_auth_provider.on_logged_out(
-            user_id=user_info.user_id,
-            device_id=user_info.device_id,
+            user_id=token.user_id,
+            device_id=token.device_id,
             access_token=access_token,
         )
 
         # delete pushers associated with this access token
-        if user_info.token_id is not None:
+        if token.token_id is not None:
             await self.hs.get_pusherpool().remove_pushers_by_access_token(
-                user_info.user_id, (user_info.token_id,)
+                token.user_id, (token.token_id,)
             )
 
     async def delete_access_tokens_for_user(
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index c05a170c55..901e2310b7 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -45,13 +45,13 @@ from synapse.types import (
     JsonDict,
     StreamKeyType,
     StreamToken,
-    UserID,
     get_domain_from_id,
     get_verify_key_from_cross_signing_key,
 )
 from synapse.util import stringutils
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.cancellation import cancellable
 from synapse.util.metrics import measure_func
 from synapse.util.retryutils import NotRetryingDestination
 
@@ -74,6 +74,7 @@ class DeviceWorkerHandler:
         self._state_storage = hs.get_storage_controllers().state
         self._auth_handler = hs.get_auth_handler()
         self.server_name = hs.hostname
+        self._msc3852_enabled = hs.config.experimental.msc3852_enabled
 
     @trace
     async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
@@ -118,11 +119,12 @@ class DeviceWorkerHandler:
         ips = await self.store.get_last_client_ip_by_device(user_id, device_id)
         _update_device_from_client_ips(device, ips)
 
-        set_tag("device", device)
-        set_tag("ips", ips)
+        set_tag("device", str(device))
+        set_tag("ips", str(ips))
 
         return device
 
+    @cancellable
     async def get_device_changes_in_shared_rooms(
         self, user_id: str, room_ids: Collection[str], from_token: StreamToken
     ) -> Collection[str]:
@@ -162,6 +164,7 @@ class DeviceWorkerHandler:
 
     @trace
     @measure_func("device.get_user_ids_changed")
+    @cancellable
     async def get_user_ids_changed(
         self, user_id: str, from_token: StreamToken
     ) -> JsonDict:
@@ -170,7 +173,7 @@ class DeviceWorkerHandler:
         """
 
         set_tag("user_id", user_id)
-        set_tag("from_token", from_token)
+        set_tag("from_token", str(from_token))
         now_room_key = self.store.get_room_max_token()
 
         room_ids = await self.store.get_rooms_for_user(user_id)
@@ -309,6 +312,7 @@ class DeviceHandler(DeviceWorkerHandler):
         super().__init__(hs)
 
         self.federation_sender = hs.get_federation_sender()
+        self._storage_controllers = hs.get_storage_controllers()
 
         self.device_list_updater = DeviceListUpdater(hs, self)
 
@@ -319,8 +323,6 @@ class DeviceHandler(DeviceWorkerHandler):
             self.device_list_updater.incoming_device_list_update,
         )
 
-        hs.get_distributor().observe("user_left_room", self.user_left_room)
-
         # Whether `_handle_new_device_update_async` is currently processing.
         self._handle_new_device_update_is_processing = False
 
@@ -564,14 +566,6 @@ class DeviceHandler(DeviceWorkerHandler):
             StreamKeyType.DEVICE_LIST, position, users=[from_user_id]
         )
 
-    async def user_left_room(self, user: UserID, room_id: str) -> None:
-        user_id = user.to_string()
-        room_ids = await self.store.get_rooms_for_user(user_id)
-        if not room_ids:
-            # We no longer share rooms with this user, so we'll no longer
-            # receive device updates. Mark this in DB.
-            await self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
-
     async def store_dehydrated_device(
         self,
         user_id: str,
@@ -693,8 +687,11 @@ class DeviceHandler(DeviceWorkerHandler):
 
                     # Ignore any users that aren't ours
                     if self.hs.is_mine_id(user_id):
-                        joined_user_ids = await self.store.get_users_in_room(room_id)
-                        hosts = {get_domain_from_id(u) for u in joined_user_ids}
+                        hosts = set(
+                            await self._storage_controllers.state.get_current_hosts_in_room(
+                                room_id
+                            )
+                        )
                         hosts.discard(self.server_name)
 
                     # Check if we've already sent this update to some hosts
@@ -747,7 +744,13 @@ def _update_device_from_client_ips(
     device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
 ) -> None:
     ip = client_ips.get((device["user_id"], device["device_id"]), {})
-    device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
+    device.update(
+        {
+            "last_seen_user_agent": ip.get("user_agent"),
+            "last_seen_ts": ip.get("last_seen"),
+            "last_seen_ip": ip.get("ip"),
+        }
+    )
 
 
 class DeviceListUpdater:
@@ -795,7 +798,7 @@ class DeviceListUpdater:
         """
 
         set_tag("origin", origin)
-        set_tag("edu_content", edu_content)
+        set_tag("edu_content", str(edu_content))
         user_id = edu_content.pop("user_id")
         device_id = edu_content.pop("device_id")
         stream_id = str(edu_content.pop("stream_id"))  # They may come as ints
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 09a7a4b238..7127d5aefc 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -30,7 +30,7 @@ from synapse.api.errors import (
 from synapse.appservice import ApplicationService
 from synapse.module_api import NOT_SPAM
 from synapse.storage.databases.main.directory import RoomAliasMapping
-from synapse.types import JsonDict, Requester, RoomAlias, UserID, get_domain_from_id
+from synapse.types import JsonDict, Requester, RoomAlias
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -83,8 +83,9 @@ class DirectoryHandler:
         # TODO(erikj): Add transactions.
         # TODO(erikj): Check if there is a current association.
         if not servers:
-            users = await self.store.get_users_in_room(room_id)
-            servers = {get_domain_from_id(u) for u in users}
+            servers = await self._storage_controllers.state.get_current_hosts_in_room(
+                room_id
+            )
 
         if not servers:
             raise SynapseError(400, "Failed to get server list")
@@ -133,7 +134,7 @@ class DirectoryHandler:
         else:
             # Server admins are not subject to the same constraints as normal
             # users when creating an alias (e.g. being in the room).
-            is_admin = await self.auth.is_server_admin(requester.user)
+            is_admin = await self.auth.is_server_admin(requester)
 
             if (self.require_membership and check_membership) and not is_admin:
                 rooms_for_user = await self.store.get_rooms_for_user(user_id)
@@ -197,7 +198,7 @@ class DirectoryHandler:
         user_id = requester.user.to_string()
 
         try:
-            can_delete = await self._user_can_delete_alias(room_alias, user_id)
+            can_delete = await self._user_can_delete_alias(room_alias, requester)
         except StoreError as e:
             if e.code == 404:
                 raise NotFoundError("Unknown room alias")
@@ -287,8 +288,9 @@ class DirectoryHandler:
                 Codes.NOT_FOUND,
             )
 
-        users = await self.store.get_users_in_room(room_id)
-        extra_servers = {get_domain_from_id(u) for u in users}
+        extra_servers = await self._storage_controllers.state.get_current_hosts_in_room(
+            room_id
+        )
         servers_set = set(extra_servers) | set(servers)
 
         # If this server is in the list of servers, return it first.
@@ -400,7 +402,9 @@ class DirectoryHandler:
         # either no interested services, or no service with an exclusive lock
         return True
 
-    async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str) -> bool:
+    async def _user_can_delete_alias(
+        self, alias: RoomAlias, requester: Requester
+    ) -> bool:
         """Determine whether a user can delete an alias.
 
         One of the following must be true:
@@ -413,7 +417,7 @@ class DirectoryHandler:
         """
         creator = await self.store.get_room_alias_creator(alias.to_string())
 
-        if creator == user_id:
+        if creator == requester.user.to_string():
             return True
 
         # Resolve the alias to the corresponding room.
@@ -422,9 +426,7 @@ class DirectoryHandler:
         if not room_id:
             return False
 
-        return await self.auth.check_can_change_room_list(
-            room_id, UserID.from_string(user_id)
-        )
+        return await self.auth.check_can_change_room_list(room_id, requester)
 
     async def edit_published_room_list(
         self, requester: Requester, room_id: str, visibility: str
@@ -463,7 +465,7 @@ class DirectoryHandler:
             raise SynapseError(400, "Unknown room")
 
         can_change_room_list = await self.auth.check_can_change_room_list(
-            room_id, requester.user
+            room_id, requester
         )
         if not can_change_room_list:
             raise AuthError(
@@ -528,10 +530,8 @@ class DirectoryHandler:
         Get a list of the aliases that currently point to this room on this server
         """
         # allow access to server admins and current members of the room
-        is_admin = await self.auth.is_server_admin(requester.user)
+        is_admin = await self.auth.is_server_admin(requester)
         if not is_admin:
-            await self.auth.check_user_in_room_or_world_readable(
-                room_id, requester.user.to_string()
-            )
+            await self.auth.check_user_in_room_or_world_readable(room_id, requester)
 
         return await self.store.get_aliases_for_room(room_id)
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 52bb5c9c55..8eed63ccf3 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple
 
 import attr
 from canonicaljson import encode_canonical_json
@@ -37,7 +37,8 @@ from synapse.types import (
     get_verify_key_from_cross_signing_key,
 )
 from synapse.util import json_decoder, unwrapFirstError
-from synapse.util.async_helpers import Linearizer
+from synapse.util.async_helpers import Linearizer, delay_cancellation
+from synapse.util.cancellation import cancellable
 from synapse.util.retryutils import NotRetryingDestination
 
 if TYPE_CHECKING:
@@ -91,8 +92,13 @@ class E2eKeysHandler:
         )
 
     @trace
+    @cancellable
     async def query_devices(
-        self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str
+        self,
+        query_body: JsonDict,
+        timeout: int,
+        from_user_id: str,
+        from_device_id: Optional[str],
     ) -> JsonDict:
         """Handle a device key query from a client
 
@@ -120,9 +126,7 @@ class E2eKeysHandler:
                 the number of in-flight queries at a time.
         """
         async with self._query_devices_linearizer.queue((from_user_id, from_device_id)):
-            device_keys_query: Dict[str, Iterable[str]] = query_body.get(
-                "device_keys", {}
-            )
+            device_keys_query: Dict[str, List[str]] = query_body.get("device_keys", {})
 
             # separate users by domain.
             # make a map from domain to user_id to device_ids
@@ -136,8 +140,8 @@ class E2eKeysHandler:
                 else:
                     remote_queries[user_id] = device_ids
 
-            set_tag("local_key_query", local_query)
-            set_tag("remote_key_query", remote_queries)
+            set_tag("local_key_query", str(local_query))
+            set_tag("remote_key_query", str(remote_queries))
 
             # First get local devices.
             # A map of destination -> failure response.
@@ -171,6 +175,32 @@ class E2eKeysHandler:
                     user_ids_not_in_cache,
                     remote_results,
                 ) = await self.store.get_user_devices_from_cache(query_list)
+
+                # Check that the homeserver still shares a room with all cached users.
+                # Note that this check may be slightly racy when a remote user leaves a
+                # room after we have fetched their cached device list. In the worst case
+                # we will do extra federation queries for devices that we had cached.
+                cached_users = set(remote_results.keys())
+                valid_cached_users = (
+                    await self.store.get_users_server_still_shares_room_with(
+                        remote_results.keys()
+                    )
+                )
+                invalid_cached_users = cached_users - valid_cached_users
+                if invalid_cached_users:
+                    # Fix up results. If we get here, there is either a bug in device
+                    # list tracking, or we hit the race mentioned above.
+                    user_ids_not_in_cache.update(invalid_cached_users)
+                    for invalid_user_id in invalid_cached_users:
+                        remote_results.pop(invalid_user_id)
+                    # This log message may be removed if it turns out it's almost
+                    # entirely triggered by races.
+                    logger.error(
+                        "Devices for %s were cached, but the server no longer shares "
+                        "any rooms with them. The cached device lists are stale.",
+                        invalid_cached_users,
+                    )
+
                 for user_id, devices in remote_results.items():
                     user_devices = results.setdefault(user_id, {})
                     for device_id, device in devices.items():
@@ -206,22 +236,26 @@ class E2eKeysHandler:
                     r[user_id] = remote_queries[user_id]
 
             # Now fetch any devices that we don't have in our cache
+            # TODO It might make sense to propagate cancellations into the
+            #      deferreds which are querying remote homeservers.
             await make_deferred_yieldable(
-                defer.gatherResults(
-                    [
-                        run_in_background(
-                            self._query_devices_for_destination,
-                            results,
-                            cross_signing_keys,
-                            failures,
-                            destination,
-                            queries,
-                            timeout,
-                        )
-                        for destination, queries in remote_queries_not_in_cache.items()
-                    ],
-                    consumeErrors=True,
-                ).addErrback(unwrapFirstError)
+                delay_cancellation(
+                    defer.gatherResults(
+                        [
+                            run_in_background(
+                                self._query_devices_for_destination,
+                                results,
+                                cross_signing_keys,
+                                failures,
+                                destination,
+                                queries,
+                                timeout,
+                            )
+                            for destination, queries in remote_queries_not_in_cache.items()
+                        ],
+                        consumeErrors=True,
+                    ).addErrback(unwrapFirstError)
+                )
             )
 
             ret = {"device_keys": results, "failures": failures}
@@ -341,10 +375,11 @@ class E2eKeysHandler:
             failure = _exception_to_failure(e)
             failures[destination] = failure
             set_tag("error", True)
-            set_tag("reason", failure)
+            set_tag("reason", str(failure))
 
         return
 
+    @cancellable
     async def get_cross_signing_keys_from_cache(
         self, query: Iterable[str], from_user_id: Optional[str]
     ) -> Dict[str, Dict[str, dict]]:
@@ -391,8 +426,9 @@ class E2eKeysHandler:
         }
 
     @trace
+    @cancellable
     async def query_local_devices(
-        self, query: Dict[str, Optional[List[str]]]
+        self, query: Mapping[str, Optional[List[str]]]
     ) -> Dict[str, Dict[str, dict]]:
         """Get E2E device keys for local users
 
@@ -403,7 +439,7 @@ class E2eKeysHandler:
         Returns:
             A map from user_id -> device_id -> device details
         """
-        set_tag("local_query", query)
+        set_tag("local_query", str(query))
         local_query: List[Tuple[str, Optional[str]]] = []
 
         result_dict: Dict[str, Dict[str, dict]] = {}
@@ -461,7 +497,7 @@ class E2eKeysHandler:
 
     @trace
     async def claim_one_time_keys(
-        self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
+        self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
     ) -> JsonDict:
         local_query: List[Tuple[str, str, str]] = []
         remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
@@ -475,8 +511,8 @@ class E2eKeysHandler:
                 domain = get_domain_from_id(user_id)
                 remote_queries.setdefault(domain, {})[user_id] = one_time_keys
 
-        set_tag("local_key_query", local_query)
-        set_tag("remote_key_query", remote_queries)
+        set_tag("local_key_query", str(local_query))
+        set_tag("remote_key_query", str(remote_queries))
 
         results = await self.store.claim_e2e_one_time_keys(local_query)
 
@@ -506,7 +542,7 @@ class E2eKeysHandler:
                 failure = _exception_to_failure(e)
                 failures[destination] = failure
                 set_tag("error", True)
-                set_tag("reason", failure)
+                set_tag("reason", str(failure))
 
         await make_deferred_yieldable(
             defer.gatherResults(
@@ -609,7 +645,7 @@ class E2eKeysHandler:
 
         result = await self.store.count_e2e_one_time_keys(user_id, device_id)
 
-        set_tag("one_time_key_counts", result)
+        set_tag("one_time_key_counts", str(result))
         return {"one_time_key_counts": result}
 
     async def _upload_one_time_keys_for_user(
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 446f509bdc..28dc08c22a 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Dict, Optional
+from typing import TYPE_CHECKING, Dict, Optional, cast
 
 from typing_extensions import Literal
 
@@ -97,7 +97,7 @@ class E2eRoomKeysHandler:
                 user_id, version, room_id, session_id
             )
 
-            log_kv(results)
+            log_kv(cast(JsonDict, results))
             return results
 
     @trace
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index a2dd9c7efa..c3ddc5d182 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -129,12 +129,9 @@ class EventAuthHandler:
         else:
             users = {}
 
-        # Find the user with the highest power level.
-        users_in_room = await self._store.get_users_in_room(room_id)
-        # Only interested in local users.
-        local_users_in_room = [
-            u for u in users_in_room if get_domain_from_id(u) == self._server_name
-        ]
+        # Find the user with the highest power level (only interested in local
+        # users).
+        local_users_in_room = await self._store.get_local_users_in_room(room_id)
         chosen_user = max(
             local_users_in_room,
             key=lambda user: users.get(user, users_default_level),
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index ac13340d3a..949b69cb41 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -151,7 +151,7 @@ class EventHandler:
         """Retrieve a single specified event.
 
         Args:
-            user: The user requesting the event
+            user: The local user requesting the event
             room_id: The expected room id. We'll return None if the
                 event's room does not match.
             event_id: The event ID to obtain.
@@ -173,8 +173,11 @@ class EventHandler:
         if not event:
             return None
 
-        users = await self.store.get_users_in_room(event.room_id)
-        is_peeking = user.to_string() not in users
+        is_user_in_room = await self.store.check_local_user_in_room(
+            user_id=user.to_string(), room_id=event.room_id
+        )
+        # The user is peeking if they aren't in the room already
+        is_peeking = not is_user_in_room
 
         filtered = await filter_events_for_client(
             self._storage_controllers, user.to_string(), [event], is_peeking=is_peeking
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 3b5eaf5156..dd4b9f66d1 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -32,6 +32,7 @@ from typing import (
 )
 
 import attr
+from prometheus_client import Histogram
 from signedjson.key import decode_verify_key_bytes
 from signedjson.sign import verify_signed_json
 from unpaddedbase64 import decode_base64
@@ -59,6 +60,7 @@ from synapse.events.validator import EventValidator
 from synapse.federation.federation_client import InvalidResponseError
 from synapse.http.servlet import assert_params_in_dict
 from synapse.logging.context import nested_logging_context
+from synapse.logging.opentracing import SynapseTags, set_tag, tag_args, trace
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.module_api import NOT_SPAM
 from synapse.replication.http.federation import (
@@ -68,7 +70,7 @@ from synapse.replication.http.federation import (
 from synapse.storage.databases.main.events import PartialStateConflictError
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.state import StateFilter
-from synapse.types import JsonDict, StateMap, get_domain_from_id
+from synapse.types import JsonDict, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
 from synapse.util.retryutils import NotRetryingDestination
 from synapse.visibility import filter_events_for_server
@@ -78,36 +80,28 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
-
-def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
-    """Get joined domains from state
-
-    Args:
-        state: State map from type/state key to event.
-
-    Returns:
-        Returns a list of servers with the lowest depth of their joins.
-            Sorted by lowest depth first.
-    """
-    joined_users = [
-        (state_key, int(event.depth))
-        for (e_type, state_key), event in state.items()
-        if e_type == EventTypes.Member and event.membership == Membership.JOIN
-    ]
-
-    joined_domains: Dict[str, int] = {}
-    for u, d in joined_users:
-        try:
-            dom = get_domain_from_id(u)
-            old_d = joined_domains.get(dom)
-            if old_d:
-                joined_domains[dom] = min(d, old_d)
-            else:
-                joined_domains[dom] = d
-        except Exception:
-            pass
-
-    return sorted(joined_domains.items(), key=lambda d: d[1])
+# Added to debug performance and track progress on optimizations
+backfill_processing_before_timer = Histogram(
+    "synapse_federation_backfill_processing_before_time_seconds",
+    "sec",
+    [],
+    buckets=(
+        0.1,
+        0.5,
+        1.0,
+        2.5,
+        5.0,
+        7.5,
+        10.0,
+        15.0,
+        20.0,
+        30.0,
+        40.0,
+        60.0,
+        80.0,
+        "+Inf",
+    ),
+)
 
 
 class _BackfillPointType(Enum):
@@ -137,6 +131,7 @@ class FederationHandler:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
 
+        self.clock = hs.get_clock()
         self.store = hs.get_datastores().main
         self._storage_controllers = hs.get_storage_controllers()
         self._state_storage_controller = self._storage_controllers.state
@@ -180,6 +175,7 @@ class FederationHandler:
                 "resume_sync_partial_state_room", self._resume_sync_partial_state_room
             )
 
+    @trace
     async def maybe_backfill(
         self, room_id: str, current_depth: int, limit: int
     ) -> bool:
@@ -195,12 +191,39 @@ class FederationHandler:
                 return. This is used as part of the heuristic to decide if we
                 should back paginate.
         """
+        # Starting the processing time here so we can include the room backfill
+        # linearizer lock queue in the timing
+        processing_start_time = self.clock.time_msec()
+
         async with self._room_backfill.queue(room_id):
-            return await self._maybe_backfill_inner(room_id, current_depth, limit)
+            return await self._maybe_backfill_inner(
+                room_id,
+                current_depth,
+                limit,
+                processing_start_time=processing_start_time,
+            )
 
     async def _maybe_backfill_inner(
-        self, room_id: str, current_depth: int, limit: int
+        self,
+        room_id: str,
+        current_depth: int,
+        limit: int,
+        *,
+        processing_start_time: int,
     ) -> bool:
+        """
+        Checks whether the `current_depth` is at or approaching any backfill
+        points in the room and if so, will backfill. We only care about
+        checking backfill points that happened before the `current_depth`
+        (meaning less than or equal to the `current_depth`).
+
+        Args:
+            room_id: The room to backfill in.
+            current_depth: The depth to check at for any upcoming backfill points.
+            limit: The max number of events to request from the remote federated server.
+            processing_start_time: The time when `maybe_backfill` started
+                processing. Only used for timing.
+        """
         backwards_extremities = [
             _BackfillPoint(event_id, depth, _BackfillPointType.BACKWARDS_EXTREMITY)
             for event_id, depth in await self.store.get_oldest_event_ids_with_depth_in_room(
@@ -368,23 +391,29 @@ class FederationHandler:
         logger.debug(
             "_maybe_backfill_inner: extremities_to_request %s", extremities_to_request
         )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "extremities_to_request",
+            str(extremities_to_request),
+        )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "extremities_to_request.length",
+            str(len(extremities_to_request)),
+        )
 
         # Now we need to decide which hosts to hit first.
-
-        # First we try hosts that are already in the room
+        # First we try hosts that are already in the room.
         # TODO: HEURISTIC ALERT.
+        likely_domains = (
+            await self._storage_controllers.state.get_current_hosts_in_room(room_id)
+        )
 
-        curr_state = await self._storage_controllers.state.get_current_state(room_id)
-
-        curr_domains = get_domains_from_state(curr_state)
-
-        likely_domains = [
-            domain for domain, depth in curr_domains if domain != self.server_name
-        ]
-
-        async def try_backfill(domains: List[str]) -> bool:
+        async def try_backfill(domains: Collection[str]) -> bool:
             # TODO: Should we try multiple of these at a time?
             for dom in domains:
+                # We don't want to ask our own server for information we don't have
+                if dom == self.server_name:
+                    continue
+
                 try:
                     await self._federation_event_handler.backfill(
                         dom, room_id, limit=100, extremities=extremities_to_request
@@ -423,6 +452,11 @@ class FederationHandler:
 
             return False
 
+        processing_end_time = self.clock.time_msec()
+        backfill_processing_before_timer.observe(
+            (processing_end_time - processing_start_time) / 1000
+        )
+
         success = await try_backfill(likely_domains)
         if success:
             return True
@@ -546,9 +580,9 @@ class FederationHandler:
             )
 
             if ret.partial_state:
-                # TODO(faster_joins): roll this back if we don't manage to start the
-                #   background resync (eg process_remote_join fails)
-                #   https://github.com/matrix-org/synapse/issues/12998
+                # Mark the room as having partial state.
+                # The background process is responsible for unmarking this flag,
+                # even if the join fails.
                 await self.store.store_partial_state_room(room_id, ret.servers_in_room)
 
             try:
@@ -574,17 +608,21 @@ class FederationHandler:
                     room_id,
                 )
                 raise LimitExceededError(msg=e.msg, errcode=e.errcode, retry_after_ms=0)
-
-            if ret.partial_state:
-                # Kick off the process of asynchronously fetching the state for this
-                # room.
-                run_as_background_process(
-                    desc="sync_partial_state_room",
-                    func=self._sync_partial_state_room,
-                    initial_destination=origin,
-                    other_destinations=ret.servers_in_room,
-                    room_id=room_id,
-                )
+            finally:
+                # Always kick off the background process that asynchronously fetches
+                # state for the room.
+                # If the join failed, the background process is responsible for
+                # cleaning up — including unmarking the room as a partial state room.
+                if ret.partial_state:
+                    # Kick off the process of asynchronously fetching the state for this
+                    # room.
+                    run_as_background_process(
+                        desc="sync_partial_state_room",
+                        func=self._sync_partial_state_room,
+                        initial_destination=origin,
+                        other_destinations=ret.servers_in_room,
+                        room_id=room_id,
+                    )
 
             # We wait here until this instance has seen the events come down
             # replication (if we're using replication) as the below uses caches.
@@ -748,6 +786,23 @@ class FederationHandler:
         # (and return a 404 otherwise)
         room_version = await self.store.get_room_version(room_id)
 
+        if await self.store.is_partial_state_room(room_id):
+            # If our server is still only partially joined, we can't give a complete
+            # response to /make_join, so return a 404 as we would if we weren't in the
+            # room at all.
+            # The main reason we can't respond properly is that we need to know about
+            # the auth events for the join event that we would return.
+            # We also should not bother entertaining the /make_join since we cannot
+            # handle the /send_join.
+            logger.info(
+                "Rejecting /make_join to %s because it's a partial state room", room_id
+            )
+            raise SynapseError(
+                404,
+                "Unable to handle /make_join right now; this server is not fully joined.",
+                errcode=Codes.NOT_FOUND,
+            )
+
         # now check that we are *still* in the room
         is_in_room = await self._event_auth_handler.check_host_in_room(
             room_id, self.server_name
@@ -1058,6 +1113,8 @@ class FederationHandler:
 
         return event
 
+    @trace
+    @tag_args
     async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
         """Returns the state at the event. i.e. not including said event."""
         event = await self.store.get_event(event_id, check_room_id=room_id)
@@ -1539,15 +1596,16 @@ class FederationHandler:
 
         # Make an infinite iterator of destinations to try. Once we find a working
         # destination, we'll stick with it until it flakes.
+        destinations: Collection[str]
         if initial_destination is not None:
             # Move `initial_destination` to the front of the list.
             destinations = list(other_destinations)
             if initial_destination in destinations:
                 destinations.remove(initial_destination)
             destinations = [initial_destination] + destinations
-            destination_iter = itertools.cycle(destinations)
         else:
-            destination_iter = itertools.cycle(other_destinations)
+            destinations = other_destinations
+        destination_iter = itertools.cycle(destinations)
 
         # `destination` is the current remote homeserver we're pulling from.
         destination = next(destination_iter)
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 9ce5bea0ed..100785ebba 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import collections
 import itertools
 import logging
 from http import HTTPStatus
@@ -28,7 +29,7 @@ from typing import (
     Tuple,
 )
 
-from prometheus_client import Counter
+from prometheus_client import Counter, Histogram
 
 from synapse import event_auth
 from synapse.api.constants import (
@@ -58,6 +59,13 @@ from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.federation.federation_client import InvalidResponseError
 from synapse.logging.context import nested_logging_context
+from synapse.logging.opentracing import (
+    SynapseTags,
+    set_tag,
+    start_active_span,
+    tag_args,
+    trace,
+)
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
 from synapse.replication.http.federation import (
@@ -90,6 +98,36 @@ soft_failed_event_counter = Counter(
     "Events received over federation that we marked as soft_failed",
 )
 
+# Added to debug performance and track progress on optimizations
+backfill_processing_after_timer = Histogram(
+    "synapse_federation_backfill_processing_after_time_seconds",
+    "sec",
+    [],
+    buckets=(
+        0.1,
+        0.25,
+        0.5,
+        1.0,
+        2.5,
+        5.0,
+        7.5,
+        10.0,
+        15.0,
+        20.0,
+        25.0,
+        30.0,
+        40.0,
+        50.0,
+        60.0,
+        80.0,
+        100.0,
+        120.0,
+        150.0,
+        180.0,
+        "+Inf",
+    ),
+)
+
 
 class FederationEventHandler:
     """Handles events that originated from federation.
@@ -277,7 +315,8 @@ class FederationEventHandler:
                 )
 
         try:
-            await self._process_received_pdu(origin, pdu, state_ids=None)
+            context = await self._state_handler.compute_event_context(pdu)
+            await self._process_received_pdu(origin, pdu, context)
         except PartialStateConflictError:
             # The room was un-partial stated while we were processing the PDU.
             # Try once more, with full state this time.
@@ -285,7 +324,8 @@ class FederationEventHandler:
                 "Room %s was un-partial stated while processing the PDU, trying again.",
                 room_id,
             )
-            await self._process_received_pdu(origin, pdu, state_ids=None)
+            context = await self._state_handler.compute_event_context(pdu)
+            await self._process_received_pdu(origin, pdu, context)
 
     async def on_send_membership_event(
         self, origin: str, event: EventBase
@@ -315,6 +355,7 @@ class FederationEventHandler:
             The event and context of the event after inserting it into the room graph.
 
         Raises:
+            RuntimeError if any prev_events are missing
             SynapseError if the event is not accepted into the room
             PartialStateConflictError if the room was un-partial stated in between
                 computing the state at the event and persisting it. The caller should
@@ -347,7 +388,7 @@ class FederationEventHandler:
         event.internal_metadata.send_on_behalf_of = origin
 
         context = await self._state_handler.compute_event_context(event)
-        context = await self._check_event_auth(origin, event, context)
+        await self._check_event_auth(origin, event, context)
         if context.rejected:
             raise SynapseError(
                 403, f"{event.membership} event was rejected", Codes.FORBIDDEN
@@ -375,7 +416,7 @@ class FederationEventHandler:
         # need to.
         await self._event_creation_handler.cache_joined_hosts_for_event(event, context)
 
-        await self._check_for_soft_fail(event, None, origin=origin)
+        await self._check_for_soft_fail(event, context=context, origin=origin)
         await self._run_push_actions_and_persist_event(event, context)
         return event, context
 
@@ -405,6 +446,7 @@ class FederationEventHandler:
             prev_member_event,
         )
 
+    @trace
     async def process_remote_join(
         self,
         origin: str,
@@ -485,7 +527,7 @@ class FederationEventHandler:
                 partial_state=partial_state,
             )
 
-            context = await self._check_event_auth(origin, event, context)
+            await self._check_event_auth(origin, event, context)
             if context.rejected:
                 raise SynapseError(400, "Join event was rejected")
 
@@ -533,32 +575,36 @@ class FederationEventHandler:
             #
             # This is the same operation as we do when we receive a regular event
             # over federation.
-            state_ids = await self._resolve_state_at_missing_prevs(destination, event)
-
-            # build a new state group for it if need be
-            context = await self._state_handler.compute_event_context(
-                event,
-                state_ids_before_event=state_ids,
+            context = await self._compute_event_context_with_maybe_missing_prevs(
+                destination, event
             )
             if context.partial_state:
                 # this can happen if some or all of the event's prev_events still have
-                # partial state - ie, an event has an earlier stream_ordering than one
-                # or more of its prev_events, so we de-partial-state it before its
-                # prev_events.
+                # partial state. We were careful to only pick events from the db without
+                # partial-state prev events, so that implies that a prev event has
+                # been persisted (with partial state) since we did the query.
                 #
-                # TODO(faster_joins): we probably need to be more intelligent, and
-                #    exclude partial-state prev_events from consideration
-                #    https://github.com/matrix-org/synapse/issues/13001
+                # So, let's just ignore `event` for now; when we re-run the db query
+                # we should instead get its partial-state prev event, which we will
+                # de-partial-state, and then come back to event.
                 logger.warning(
-                    "%s still has partial state: can't de-partial-state it yet",
+                    "%s still has prev_events with partial state: can't de-partial-state it yet",
                     event.event_id,
                 )
                 return
+
+            # since the state at this event has changed, we should now re-evaluate
+            # whether it should have been rejected. We must already have all of the
+            # auth events (from last time we went round this path), so there is no
+            # need to pass the origin.
+            await self._check_event_auth(None, event, context)
+
             await self._store.update_state_for_partial_state_event(event, context)
             self._state_storage_controller.notify_event_un_partial_stated(
                 event.event_id
             )
 
+    @trace
     async def backfill(
         self, dest: str, room_id: str, limit: int, extremities: Collection[str]
     ) -> None:
@@ -588,21 +634,23 @@ class FederationEventHandler:
         if not events:
             return
 
-        # if there are any events in the wrong room, the remote server is buggy and
-        # should not be trusted.
-        for ev in events:
-            if ev.room_id != room_id:
-                raise InvalidResponseError(
-                    f"Remote server {dest} returned event {ev.event_id} which is in "
-                    f"room {ev.room_id}, when we were backfilling in {room_id}"
-                )
+        with backfill_processing_after_timer.time():
+            # if there are any events in the wrong room, the remote server is buggy and
+            # should not be trusted.
+            for ev in events:
+                if ev.room_id != room_id:
+                    raise InvalidResponseError(
+                        f"Remote server {dest} returned event {ev.event_id} which is in "
+                        f"room {ev.room_id}, when we were backfilling in {room_id}"
+                    )
 
-        await self._process_pulled_events(
-            dest,
-            events,
-            backfilled=True,
-        )
+            await self._process_pulled_events(
+                dest,
+                events,
+                backfilled=True,
+            )
 
+    @trace
     async def _get_missing_events_for_pdu(
         self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
     ) -> None:
@@ -703,8 +751,9 @@ class FederationEventHandler:
         logger.info("Got %d prev_events", len(missing_events))
         await self._process_pulled_events(origin, missing_events, backfilled=False)
 
+    @trace
     async def _process_pulled_events(
-        self, origin: str, events: Iterable[EventBase], backfilled: bool
+        self, origin: str, events: Collection[EventBase], backfilled: bool
     ) -> None:
         """Process a batch of events we have pulled from a remote server
 
@@ -719,6 +768,15 @@ class FederationEventHandler:
             backfilled: True if this is part of a historical batch of events (inhibits
                 notification to clients, and validation of device keys.)
         """
+        set_tag(
+            SynapseTags.FUNC_ARG_PREFIX + "event_ids",
+            str([event.event_id for event in events]),
+        )
+        set_tag(
+            SynapseTags.FUNC_ARG_PREFIX + "event_ids.length",
+            str(len(events)),
+        )
+        set_tag(SynapseTags.FUNC_ARG_PREFIX + "backfilled", str(backfilled))
         logger.debug(
             "processing pulled backfilled=%s events=%s",
             backfilled,
@@ -741,6 +799,8 @@ class FederationEventHandler:
             with nested_logging_context(ev.event_id):
                 await self._process_pulled_event(origin, ev, backfilled=backfilled)
 
+    @trace
+    @tag_args
     async def _process_pulled_event(
         self, origin: str, event: EventBase, backfilled: bool
     ) -> None:
@@ -765,10 +825,24 @@ class FederationEventHandler:
         """
         logger.info("Processing pulled event %s", event)
 
-        # these should not be outliers.
-        assert (
-            not event.internal_metadata.is_outlier()
-        ), "pulled event unexpectedly flagged as outlier"
+        # This function should not be used to persist outliers (use something
+        # else) because this does a bunch of operations that aren't necessary
+        # (extra work; in particular, it makes sure we have all the prev_events
+        # and resolves the state across those prev events). If you happen to run
+        # into a situation where the event you're trying to process/backfill is
+        # marked as an `outlier`, then you should update that spot to return an
+        # `EventBase` copy that doesn't have `outlier` flag set.
+        #
+        # `EventBase` is used to represent both an event we have not yet
+        # persisted, and one that we have persisted and now keep in the cache.
+        # In an ideal world this method would only be called with the first type
+        # of event, but it turns out that's not actually the case and for
+        # example, you could get an event from cache that is marked as an
+        # `outlier` (fix up that spot though).
+        assert not event.internal_metadata.is_outlier(), (
+            "Outlier event passed to _process_pulled_event. "
+            "To persist an event as a non-outlier, make sure to pass in a copy without `event.internal_metadata.outlier = true`."
+        )
 
         event_id = event.event_id
 
@@ -778,7 +852,7 @@ class FederationEventHandler:
         if existing:
             if not existing.internal_metadata.is_outlier():
                 logger.info(
-                    "Ignoring received event %s which we have already seen",
+                    "_process_pulled_event: Ignoring received event %s which we have already seen",
                     event_id,
                 )
                 return
@@ -788,32 +862,66 @@ class FederationEventHandler:
             self._sanity_check_event(event)
         except SynapseError as err:
             logger.warning("Event %s failed sanity check: %s", event_id, err)
+            await self._store.record_event_failed_pull_attempt(
+                event.room_id, event_id, str(err)
+            )
             return
 
         try:
-            state_ids = await self._resolve_state_at_missing_prevs(origin, event)
-            # TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does
-            #   not return partial state
-            #   https://github.com/matrix-org/synapse/issues/13002
+            try:
+                context = await self._compute_event_context_with_maybe_missing_prevs(
+                    origin, event
+                )
+                await self._process_received_pdu(
+                    origin,
+                    event,
+                    context,
+                    backfilled=backfilled,
+                )
+            except PartialStateConflictError:
+                # The room was un-partial stated while we were processing the event.
+                # Try once more, with full state this time.
+                context = await self._compute_event_context_with_maybe_missing_prevs(
+                    origin, event
+                )
 
-            await self._process_received_pdu(
-                origin, event, state_ids=state_ids, backfilled=backfilled
-            )
+                # We ought to have full state now, barring some unlikely race where we left and
+                # rejoned the room in the background.
+                if context.partial_state:
+                    raise AssertionError(
+                        f"Event {event.event_id} still has a partial resolved state "
+                        f"after room {event.room_id} was un-partial stated"
+                    )
+
+                await self._process_received_pdu(
+                    origin,
+                    event,
+                    context,
+                    backfilled=backfilled,
+                )
         except FederationError as e:
+            await self._store.record_event_failed_pull_attempt(
+                event.room_id, event_id, str(e)
+            )
+
             if e.code == 403:
                 logger.warning("Pulled event %s failed history check.", event_id)
             else:
                 raise
 
-    async def _resolve_state_at_missing_prevs(
+    @trace
+    async def _compute_event_context_with_maybe_missing_prevs(
         self, dest: str, event: EventBase
-    ) -> Optional[StateMap[str]]:
-        """Calculate the state at an event with missing prev_events.
+    ) -> EventContext:
+        """Build an EventContext structure for a non-outlier event whose prev_events may
+        be missing.
 
-        This is used when we have pulled a batch of events from a remote server, and
-        still don't have all the prev_events.
+        This is used when we have pulled a batch of events from a remote server, and may
+        not have all the prev_events.
 
-        If we already have all the prev_events for `event`, this method does nothing.
+        To build an EventContext, we need to calculate the state before the event. If we
+        already have all the prev_events for `event`, we can simply use the state after
+        the prev_events to calculate the state before `event`.
 
         Otherwise, the missing prevs become new backwards extremities, and we fall back
         to asking the remote server for the state after each missing `prev_event`,
@@ -834,8 +942,7 @@ class FederationEventHandler:
             event: an event to check for missing prevs.
 
         Returns:
-            if we already had all the prev events, `None`. Otherwise, returns
-            the event ids of the state at `event`.
+            The event context.
 
         Raises:
             FederationError if we fail to get the state from the remote server after any
@@ -849,7 +956,7 @@ class FederationEventHandler:
         missing_prevs = prevs - seen
 
         if not missing_prevs:
-            return None
+            return await self._state_handler.compute_event_context(event)
 
         logger.info(
             "Event %s is missing prev_events %s: calculating state for a "
@@ -861,9 +968,15 @@ class FederationEventHandler:
         # resolve them to find the correct state at the current event.
 
         try:
+            # Determine whether we may be about to retrieve partial state
+            # Events may be un-partial stated right after we compute the partial state
+            # flag, but that's okay, as long as the flag errs on the conservative side.
+            partial_state_flags = await self._store.get_partial_state_events(seen)
+            partial_state = any(partial_state_flags.values())
+
             # Get the state of the events we know about
             ours = await self._state_storage_controller.get_state_groups_ids(
-                room_id, seen
+                room_id, seen, await_full_state=False
             )
 
             # state_maps is a list of mappings from (type, state_key) to event_id
@@ -909,8 +1022,12 @@ class FederationEventHandler:
                 "We can't get valid state history.",
                 affected=event_id,
             )
-        return state_map
+        return await self._state_handler.compute_event_context(
+            event, state_ids_before_event=state_map, partial_state=partial_state
+        )
 
+    @trace
+    @tag_args
     async def _get_state_ids_after_missing_prev_event(
         self,
         destination: str,
@@ -931,6 +1048,14 @@ class FederationEventHandler:
             InvalidResponseError: if the remote homeserver's response contains fields
                 of the wrong type.
         """
+
+        # It would be better if we could query the difference from our known
+        # state to the given `event_id` so the sending server doesn't have to
+        # send as much and we don't have to process as many events. For example
+        # in a room like #matrix:matrix.org, we get 200k events (77k state_events, 122k
+        # auth_events) from this call.
+        #
+        # Tracked by https://github.com/matrix-org/synapse/issues/13618
         (
             state_event_ids,
             auth_event_ids,
@@ -955,11 +1080,11 @@ class FederationEventHandler:
         )
         have_events = await self._store.have_seen_events(room_id, desired_events)
 
-        missing_desired_events = desired_events - have_events
+        missing_desired_event_ids = desired_events - have_events
         logger.debug(
             "_get_state_ids_after_missing_prev_event(event_id=%s): We are missing %i events (got %i)",
             event_id,
-            len(missing_desired_events),
+            len(missing_desired_event_ids),
             len(have_events),
         )
 
@@ -971,17 +1096,34 @@ class FederationEventHandler:
         #   already have a bunch of the state events. It would be nice if the
         #   federation api gave us a way of finding out which we actually need.
 
-        missing_auth_events = set(auth_event_ids) - have_events
-        missing_auth_events.difference_update(
-            await self._store.have_seen_events(room_id, missing_auth_events)
+        missing_auth_event_ids = set(auth_event_ids) - have_events
+        missing_auth_event_ids.difference_update(
+            await self._store.have_seen_events(room_id, missing_auth_event_ids)
         )
         logger.debug(
             "_get_state_ids_after_missing_prev_event(event_id=%s): We are also missing %i auth events",
             event_id,
-            len(missing_auth_events),
+            len(missing_auth_event_ids),
         )
 
-        missing_events = missing_desired_events | missing_auth_events
+        missing_event_ids = missing_desired_event_ids | missing_auth_event_ids
+
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "missing_auth_event_ids",
+            str(missing_auth_event_ids),
+        )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "missing_auth_event_ids.length",
+            str(len(missing_auth_event_ids)),
+        )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "missing_desired_event_ids",
+            str(missing_desired_event_ids),
+        )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "missing_desired_event_ids.length",
+            str(len(missing_desired_event_ids)),
+        )
 
         # Making an individual request for each of 1000s of events has a lot of
         # overhead. On the other hand, we don't really want to fetch all of the events
@@ -992,7 +1134,7 @@ class FederationEventHandler:
         #
         # TODO: might it be better to have an API which lets us do an aggregate event
         #   request
-        if (len(missing_events) * 10) >= len(auth_event_ids) + len(state_event_ids):
+        if (len(missing_event_ids) * 10) >= len(auth_event_ids) + len(state_event_ids):
             logger.debug(
                 "_get_state_ids_after_missing_prev_event(event_id=%s): Requesting complete state from remote",
                 event_id,
@@ -1002,10 +1144,10 @@ class FederationEventHandler:
             logger.debug(
                 "_get_state_ids_after_missing_prev_event(event_id=%s): Fetching %i events from remote",
                 event_id,
-                len(missing_events),
+                len(missing_event_ids),
             )
             await self._get_events_and_persist(
-                destination=destination, room_id=room_id, event_ids=missing_events
+                destination=destination, room_id=room_id, event_ids=missing_event_ids
             )
 
         # We now need to fill out the state map, which involves fetching the
@@ -1063,12 +1205,24 @@ class FederationEventHandler:
         # missing state at that event is a warning, not a blocker
         # XXX: this doesn't sound right? it means that we'll end up with incomplete
         #   state.
+        failed_to_fetch = desired_events - event_metadata.keys()
+        # `event_id` could be missing from `event_metadata` because it's not necessarily
+        # a state event. We've already checked that we've fetched it above.
+        failed_to_fetch.discard(event_id)
         if failed_to_fetch:
             logger.warning(
                 "_get_state_ids_after_missing_prev_event(event_id=%s): Failed to fetch missing state events %s",
                 event_id,
                 failed_to_fetch,
             )
+            set_tag(
+                SynapseTags.RESULT_PREFIX + "failed_to_fetch",
+                str(failed_to_fetch),
+            )
+            set_tag(
+                SynapseTags.RESULT_PREFIX + "failed_to_fetch.length",
+                str(len(failed_to_fetch)),
+            )
 
         if remote_event.is_state() and remote_event.rejected_reason is None:
             state_map[
@@ -1077,6 +1231,8 @@ class FederationEventHandler:
 
         return state_map
 
+    @trace
+    @tag_args
     async def _get_state_and_persist(
         self, destination: str, room_id: str, event_id: str
     ) -> None:
@@ -1098,11 +1254,12 @@ class FederationEventHandler:
                 destination=destination, room_id=room_id, event_ids=(event_id,)
             )
 
+    @trace
     async def _process_received_pdu(
         self,
         origin: str,
         event: EventBase,
-        state_ids: Optional[StateMap[str]],
+        context: EventContext,
         backfilled: bool = False,
     ) -> None:
         """Called when we have a new non-outlier event.
@@ -1124,30 +1281,20 @@ class FederationEventHandler:
 
             event: event to be persisted
 
-            state_ids: Normally None, but if we are handling a gap in the graph
-                (ie, we are missing one or more prev_events), the resolved state at the
-                event. Must not be partial state.
+            context: The `EventContext` to persist the event with.
 
             backfilled: True if this is part of a historical batch of events (inhibits
                 notification to clients, and validation of device keys.)
 
         PartialStateConflictError: if the room was un-partial stated in between
-            computing the state at the event and persisting it. The caller should retry
-            exactly once in this case. Will never be raised if `state_ids` is provided.
+            computing the state at the event and persisting it. The caller should
+            recompute `context` and retry exactly once when this happens.
         """
         logger.debug("Processing event: %s", event)
         assert not event.internal_metadata.outlier
 
-        context = await self._state_handler.compute_event_context(
-            event,
-            state_ids_before_event=state_ids,
-        )
         try:
-            context = await self._check_event_auth(
-                origin,
-                event,
-                context,
-            )
+            await self._check_event_auth(origin, event, context)
         except AuthError as e:
             # This happens only if we couldn't find the auth events. We'll already have
             # logged a warning, so now we just convert to a FederationError.
@@ -1157,7 +1304,7 @@ class FederationEventHandler:
             # For new (non-backfilled and non-outlier) events we check if the event
             # passes auth based on the current state. If it doesn't then we
             # "soft-fail" the event.
-            await self._check_for_soft_fail(event, state_ids, origin=origin)
+            await self._check_for_soft_fail(event, context=context, origin=origin)
 
         await self._run_push_actions_and_persist_event(event, context, backfilled)
 
@@ -1258,6 +1405,7 @@ class FederationEventHandler:
         except Exception:
             logger.exception("Failed to resync device for %s", sender)
 
+    @trace
     async def _handle_marker_event(self, origin: str, marker_event: EventBase) -> None:
         """Handles backfilling the insertion event when we receive a marker
         event that points to one.
@@ -1289,7 +1437,7 @@ class FederationEventHandler:
         logger.debug("_handle_marker_event: received %s", marker_event)
 
         insertion_event_id = marker_event.content.get(
-            EventContentFields.MSC2716_MARKER_INSERTION
+            EventContentFields.MSC2716_INSERTION_EVENT_REFERENCE
         )
 
         if insertion_event_id is None:
@@ -1342,6 +1490,55 @@ class FederationEventHandler:
             marker_event,
         )
 
+    async def backfill_event_id(
+        self, destination: str, room_id: str, event_id: str
+    ) -> EventBase:
+        """Backfill a single event and persist it as a non-outlier which means
+        we also pull in all of the state and auth events necessary for it.
+
+        Args:
+            destination: The homeserver to pull the given event_id from.
+            room_id: The room where the event is from.
+            event_id: The event ID to backfill.
+
+        Raises:
+            FederationError if we are unable to find the event from the destination
+        """
+        logger.info(
+            "backfill_event_id: event_id=%s from destination=%s", event_id, destination
+        )
+
+        room_version = await self._store.get_room_version(room_id)
+
+        event_from_response = await self._federation_client.get_pdu(
+            [destination],
+            event_id,
+            room_version,
+        )
+
+        if not event_from_response:
+            raise FederationError(
+                "ERROR",
+                404,
+                "Unable to find event_id=%s from destination=%s to backfill."
+                % (event_id, destination),
+                affected=event_id,
+            )
+
+        # Persist the event we just fetched, including pulling all of the state
+        # and auth events to de-outlier it. This also sets up the necessary
+        # `state_groups` for the event.
+        await self._process_pulled_events(
+            destination,
+            [event_from_response],
+            # Prevent notifications going to clients
+            backfilled=True,
+        )
+
+        return event_from_response
+
+    @trace
+    @tag_args
     async def _get_events_and_persist(
         self, destination: str, room_id: str, event_ids: Collection[str]
     ) -> None:
@@ -1387,6 +1584,7 @@ class FederationEventHandler:
         logger.info("Fetched %i events of %i requested", len(events), len(event_ids))
         await self._auth_and_persist_outliers(room_id, events)
 
+    @trace
     async def _auth_and_persist_outliers(
         self, room_id: str, events: Iterable[EventBase]
     ) -> None:
@@ -1405,6 +1603,16 @@ class FederationEventHandler:
         """
         event_map = {event.event_id: event for event in events}
 
+        event_ids = event_map.keys()
+        set_tag(
+            SynapseTags.FUNC_ARG_PREFIX + "event_ids",
+            str(event_ids),
+        )
+        set_tag(
+            SynapseTags.FUNC_ARG_PREFIX + "event_ids.length",
+            str(len(event_ids)),
+        )
+
         # filter out any events we have already seen. This might happen because
         # the events were eagerly pushed to us (eg, during a room join), or because
         # another thread has raced against us since we decided to request the event.
@@ -1521,24 +1729,21 @@ class FederationEventHandler:
             backfilled=True,
         )
 
+    @trace
     async def _check_event_auth(
-        self,
-        origin: str,
-        event: EventBase,
-        context: EventContext,
-    ) -> EventContext:
+        self, origin: Optional[str], event: EventBase, context: EventContext
+    ) -> None:
         """
         Checks whether an event should be rejected (for failing auth checks).
 
         Args:
-            origin: The host the event originates from.
+            origin: The host the event originates from. This is used to fetch
+               any missing auth events. It can be set to None, but only if we are
+               sure that we already have all the auth events.
             event: The event itself.
             context:
                 The event context.
 
-        Returns:
-            The updated context object.
-
         Raises:
             AuthError if we were unable to find copies of the event's auth events.
                (Most other failures just cause us to set `context.rejected`.)
@@ -1553,7 +1758,7 @@ class FederationEventHandler:
             logger.warning("While validating received event %r: %s", event, e)
             # TODO: use a different rejected reason here?
             context.rejected = RejectedReason.AUTH_ERROR
-            return context
+            return
 
         # next, check that we have all of the event's auth events.
         #
@@ -1563,8 +1768,19 @@ class FederationEventHandler:
         claimed_auth_events = await self._load_or_fetch_auth_events_for_event(
             origin, event
         )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "claimed_auth_events",
+            str([ev.event_id for ev in claimed_auth_events]),
+        )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "claimed_auth_events.length",
+            str(len(claimed_auth_events)),
+        )
 
         # ... and check that the event passes auth at those auth events.
+        # https://spec.matrix.org/v1.3/server-server-api/#checks-performed-on-receipt-of-a-pdu:
+        #   4. Passes authorization rules based on the event’s auth events,
+        #      otherwise it is rejected.
         try:
             await check_state_independent_auth_rules(self._store, event)
             check_state_dependent_auth_rules(event, claimed_auth_events)
@@ -1573,55 +1789,91 @@ class FederationEventHandler:
                 "While checking auth of %r against auth_events: %s", event, e
             )
             context.rejected = RejectedReason.AUTH_ERROR
-            return context
+            return
+
+        # now check the auth rules pass against the room state before the event
+        # https://spec.matrix.org/v1.3/server-server-api/#checks-performed-on-receipt-of-a-pdu:
+        #   5. Passes authorization rules based on the state before the event,
+        #      otherwise it is rejected.
+        #
+        # ... however, if we only have partial state for the room, then there is a good
+        # chance that we'll be missing some of the state needed to auth the new event.
+        # So, we state-resolve the auth events that we are given against the state that
+        # we know about, which ensures things like bans are applied. (Note that we'll
+        # already have checked we have all the auth events, in
+        # _load_or_fetch_auth_events_for_event above)
+        if context.partial_state:
+            room_version = await self._store.get_room_version_id(event.room_id)
+
+            local_state_id_map = await context.get_prev_state_ids()
+            claimed_auth_events_id_map = {
+                (ev.type, ev.state_key): ev.event_id for ev in claimed_auth_events
+            }
+
+            state_for_auth_id_map = (
+                await self._state_resolution_handler.resolve_events_with_store(
+                    event.room_id,
+                    room_version,
+                    [local_state_id_map, claimed_auth_events_id_map],
+                    event_map=None,
+                    state_res_store=StateResolutionStore(self._store),
+                )
+            )
+        else:
+            event_types = event_auth.auth_types_for_event(event.room_version, event)
+            state_for_auth_id_map = await context.get_prev_state_ids(
+                StateFilter.from_types(event_types)
+            )
 
-        # now check auth against what we think the auth events *should* be.
-        event_types = event_auth.auth_types_for_event(event.room_version, event)
-        prev_state_ids = await context.get_prev_state_ids(
-            StateFilter.from_types(event_types)
+        calculated_auth_event_ids = self._event_auth_handler.compute_auth_events(
+            event, state_for_auth_id_map, for_verification=True
         )
 
-        auth_events_ids = self._event_auth_handler.compute_auth_events(
-            event, prev_state_ids, for_verification=True
+        # if those are the same, we're done here.
+        if collections.Counter(event.auth_event_ids()) == collections.Counter(
+            calculated_auth_event_ids
+        ):
+            return
+
+        # otherwise, re-run the auth checks based on what we calculated.
+        calculated_auth_events = await self._store.get_events_as_list(
+            calculated_auth_event_ids
         )
-        auth_events_x = await self._store.get_events(auth_events_ids)
+
+        # log the differences
+
+        claimed_auth_event_map = {(e.type, e.state_key): e for e in claimed_auth_events}
         calculated_auth_event_map = {
-            (e.type, e.state_key): e for e in auth_events_x.values()
+            (e.type, e.state_key): e for e in calculated_auth_events
         }
+        logger.info(
+            "event's auth_events are different to our calculated auth_events. "
+            "Claimed but not calculated: %s. Calculated but not claimed: %s",
+            [
+                ev
+                for k, ev in claimed_auth_event_map.items()
+                if k not in calculated_auth_event_map
+                or calculated_auth_event_map[k].event_id != ev.event_id
+            ],
+            [
+                ev
+                for k, ev in calculated_auth_event_map.items()
+                if k not in claimed_auth_event_map
+                or claimed_auth_event_map[k].event_id != ev.event_id
+            ],
+        )
 
         try:
-            updated_auth_events = await self._update_auth_events_for_auth(
+            check_state_dependent_auth_rules(event, calculated_auth_events)
+        except AuthError as e:
+            logger.warning(
+                "While checking auth of %r against room state before the event: %s",
                 event,
-                calculated_auth_event_map=calculated_auth_event_map,
-            )
-        except Exception:
-            # We don't really mind if the above fails, so lets not fail
-            # processing if it does. However, it really shouldn't fail so
-            # let's still log as an exception since we'll still want to fix
-            # any bugs.
-            logger.exception(
-                "Failed to double check auth events for %s with remote. "
-                "Ignoring failure and continuing processing of event.",
-                event.event_id,
-            )
-            updated_auth_events = None
-
-        if updated_auth_events:
-            context = await self._update_context_for_auth_events(
-                event, context, updated_auth_events
+                e,
             )
-            auth_events_for_auth = updated_auth_events
-        else:
-            auth_events_for_auth = calculated_auth_event_map
-
-        try:
-            check_state_dependent_auth_rules(event, auth_events_for_auth.values())
-        except AuthError as e:
-            logger.warning("Failed auth resolution for %r because %s", event, e)
             context.rejected = RejectedReason.AUTH_ERROR
 
-        return context
-
+    @trace
     async def _maybe_kick_guest_users(self, event: EventBase) -> None:
         if event.type != EventTypes.GuestAccess:
             return
@@ -1639,17 +1891,27 @@ class FederationEventHandler:
     async def _check_for_soft_fail(
         self,
         event: EventBase,
-        state_ids: Optional[StateMap[str]],
+        context: EventContext,
         origin: str,
     ) -> None:
         """Checks if we should soft fail the event; if so, marks the event as
         such.
 
+        Does nothing for events in rooms with partial state, since we may not have an
+        accurate membership event for the sender in the current state.
+
         Args:
             event
-            state_ids: The state at the event if we don't have all the event's prev events
+            context: The `EventContext` which we are about to persist the event with.
             origin: The host the event originates from.
         """
+        if await self._store.is_partial_state_room(event.room_id):
+            # We might not know the sender's membership in the current state, so don't
+            # soft fail anything. Even if we do have a membership for the sender in the
+            # current state, it may have been derived from state resolution between
+            # partial and full state and may not be accurate.
+            return
+
         extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id)
         extrem_ids = set(extrem_ids_list)
         prev_event_ids = set(event.prev_event_ids())
@@ -1666,11 +1928,15 @@ class FederationEventHandler:
         auth_types = auth_types_for_event(room_version_obj, event)
 
         # Calculate the "current state".
-        if state_ids is not None:
-            # If we're explicitly given the state then we won't have all the
-            # prev events, and so we have a gap in the graph. In this case
-            # we want to be a little careful as we might have been down for
-            # a while and have an incorrect view of the current state,
+        seen_event_ids = await self._store.have_events_in_timeline(prev_event_ids)
+        has_missing_prevs = bool(prev_event_ids - seen_event_ids)
+        if has_missing_prevs:
+            # We don't have all the prev_events of this event, which means we have a
+            # gap in the graph, and the new event is going to become a new backwards
+            # extremity.
+            #
+            # In this case we want to be a little careful as we might have been
+            # down for a while and have an incorrect view of the current state,
             # however we still want to do checks as gaps are easy to
             # maliciously manufacture.
             #
@@ -1683,6 +1949,7 @@ class FederationEventHandler:
                 event.room_id, extrem_ids
             )
             state_sets: List[StateMap[str]] = list(state_sets_d.values())
+            state_ids = await context.get_prev_state_ids()
             state_sets.append(state_ids)
             current_state_ids = (
                 await self._state_resolution_handler.resolve_events_with_store(
@@ -1731,95 +1998,8 @@ class FederationEventHandler:
             soft_failed_event_counter.inc()
             event.internal_metadata.soft_failed = True
 
-    async def _update_auth_events_for_auth(
-        self,
-        event: EventBase,
-        calculated_auth_event_map: StateMap[EventBase],
-    ) -> Optional[StateMap[EventBase]]:
-        """Helper for _check_event_auth. See there for docs.
-
-        Checks whether a given event has the expected auth events. If it
-        doesn't then we talk to the remote server to compare state to see if
-        we can come to a consensus (e.g. if one server missed some valid
-        state).
-
-        This attempts to resolve any potential divergence of state between
-        servers, but is not essential and so failures should not block further
-        processing of the event.
-
-        Args:
-            event:
-
-            calculated_auth_event_map:
-                Our calculated auth_events based on the state of the room
-                at the event's position in the DAG.
-
-        Returns:
-            updated auth event map, or None if no changes are needed.
-
-        """
-        assert not event.internal_metadata.outlier
-
-        # check for events which are in the event's claimed auth_events, but not
-        # in our calculated event map.
-        event_auth_events = set(event.auth_event_ids())
-        different_auth = event_auth_events.difference(
-            e.event_id for e in calculated_auth_event_map.values()
-        )
-
-        if not different_auth:
-            return None
-
-        logger.info(
-            "auth_events refers to events which are not in our calculated auth "
-            "chain: %s",
-            different_auth,
-        )
-
-        # XXX: currently this checks for redactions but I'm not convinced that is
-        # necessary?
-        different_events = await self._store.get_events_as_list(different_auth)
-
-        # double-check they're all in the same room - we should already have checked
-        # this but it doesn't hurt to check again.
-        for d in different_events:
-            assert (
-                d.room_id == event.room_id
-            ), f"Event {event.event_id} refers to auth_event {d.event_id} which is in a different room"
-
-        # now we state-resolve between our own idea of the auth events, and the remote's
-        # idea of them.
-
-        local_state = calculated_auth_event_map.values()
-        remote_auth_events = dict(calculated_auth_event_map)
-        remote_auth_events.update({(d.type, d.state_key): d for d in different_events})
-        remote_state = remote_auth_events.values()
-
-        room_version = await self._store.get_room_version_id(event.room_id)
-        new_state = await self._state_handler.resolve_events(
-            room_version, (local_state, remote_state), event
-        )
-        different_state = {
-            (d.type, d.state_key): d
-            for d in new_state.values()
-            if calculated_auth_event_map.get((d.type, d.state_key)) != d
-        }
-        if not different_state:
-            logger.info("State res returned no new state")
-            return None
-
-        logger.info(
-            "After state res: updating auth_events with new state %s",
-            different_state.values(),
-        )
-
-        # take a copy of calculated_auth_event_map before we modify it.
-        auth_events = dict(calculated_auth_event_map)
-        auth_events.update(different_state)
-        return auth_events
-
     async def _load_or_fetch_auth_events_for_event(
-        self, destination: str, event: EventBase
+        self, destination: Optional[str], event: EventBase
     ) -> Collection[EventBase]:
         """Fetch this event's auth_events, from database or remote
 
@@ -1835,12 +2015,19 @@ class FederationEventHandler:
         Args:
             destination: where to send the /event_auth request. Typically the server
                that sent us `event` in the first place.
+
+               If this is None, no attempt is made to load any missing auth events:
+               rather, an AssertionError is raised if there are any missing events.
+
             event: the event whose auth_events we want
 
         Returns:
             all of the events listed in `event.auth_events_ids`, after deduplication
 
         Raises:
+            AssertionError if some auth events were missing and no `destination` was
+            supplied.
+
             AuthError if we were unable to fetch the auth_events for any reason.
         """
         event_auth_event_ids = set(event.auth_event_ids())
@@ -1852,6 +2039,13 @@ class FederationEventHandler:
         )
         if not missing_auth_event_ids:
             return event_auth_events.values()
+        if destination is None:
+            # this shouldn't happen: destination must be set unless we know we have already
+            # persisted the auth events.
+            raise AssertionError(
+                "_load_or_fetch_auth_events_for_event() called with no destination for "
+                "an event with missing auth_events"
+            )
 
         logger.info(
             "Event %s refers to unknown auth events %s: fetching auth chain",
@@ -1887,6 +2081,8 @@ class FederationEventHandler:
         # instead we raise an AuthError, which will make the caller ignore it.
         raise AuthError(code=HTTPStatus.FORBIDDEN, msg="Auth events could not be found")
 
+    @trace
+    @tag_args
     async def _get_remote_auth_chain_for_event(
         self, destination: str, room_id: str, event_id: str
     ) -> None:
@@ -1915,61 +2111,7 @@ class FederationEventHandler:
 
         await self._auth_and_persist_outliers(room_id, remote_auth_events)
 
-    async def _update_context_for_auth_events(
-        self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase]
-    ) -> EventContext:
-        """Update the state_ids in an event context after auth event resolution,
-        storing the changes as a new state group.
-
-        Args:
-            event: The event we're handling the context for
-
-            context: initial event context
-
-            auth_events: Events to update in the event context.
-
-        Returns:
-            new event context
-        """
-        # exclude the state key of the new event from the current_state in the context.
-        if event.is_state():
-            event_key: Optional[Tuple[str, str]] = (event.type, event.state_key)
-        else:
-            event_key = None
-        state_updates = {
-            k: a.event_id for k, a in auth_events.items() if k != event_key
-        }
-
-        current_state_ids = await context.get_current_state_ids()
-        current_state_ids = dict(current_state_ids)  # type: ignore
-
-        current_state_ids.update(state_updates)
-
-        prev_state_ids = await context.get_prev_state_ids()
-        prev_state_ids = dict(prev_state_ids)
-
-        prev_state_ids.update({k: a.event_id for k, a in auth_events.items()})
-
-        # create a new state group as a delta from the existing one.
-        prev_group = context.state_group
-        state_group = await self._state_storage_controller.store_state_group(
-            event.event_id,
-            event.room_id,
-            prev_group=prev_group,
-            delta_ids=state_updates,
-            current_state_ids=current_state_ids,
-        )
-
-        return EventContext.with_state(
-            storage=self._storage_controllers,
-            state_group=state_group,
-            state_group_before_event=context.state_group_before_event,
-            state_delta_due_to_event=state_updates,
-            prev_group=prev_group,
-            delta_ids=state_updates,
-            partial_state=context.partial_state,
-        )
-
+    @trace
     async def _run_push_actions_and_persist_event(
         self, event: EventBase, context: EventContext, backfilled: bool = False
     ) -> None:
@@ -2078,8 +2220,17 @@ class FederationEventHandler:
                     self._message_handler.maybe_schedule_expiry(event)
 
             if not backfilled:  # Never notify for backfilled events
-                for event in events:
-                    await self._notify_persisted_event(event, max_stream_token)
+                with start_active_span("notify_persisted_events"):
+                    set_tag(
+                        SynapseTags.RESULT_PREFIX + "event_ids",
+                        str([ev.event_id for ev in events]),
+                    )
+                    set_tag(
+                        SynapseTags.RESULT_PREFIX + "event_ids.length",
+                        str(len(events)),
+                    )
+                    for event in events:
+                        await self._notify_persisted_event(event, max_stream_token)
 
             return max_stream_token.stream
 
@@ -2120,6 +2271,10 @@ class FederationEventHandler:
             event, event_pos, max_stream_token, extra_users=extra_users
         )
 
+        if event.type == EventTypes.Member and event.membership == Membership.JOIN:
+            # TODO retrieve the previous state, and exclude join -> join transitions
+            self._notifier.notify_user_joined_room(event.event_id, event.room_id)
+
     def _sanity_check_event(self, ev: EventBase) -> None:
         """
         Do some early sanity checks of a received event
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 9bca2bc4b2..93d09e9939 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -26,7 +26,6 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.api.ratelimiting import Ratelimiter
-from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.http import RequestTimedOutError
 from synapse.http.client import SimpleHttpClient
 from synapse.http.site import SynapseRequest
@@ -163,8 +162,7 @@ class IdentityHandler:
         sid: str,
         mxid: str,
         id_server: str,
-        id_access_token: Optional[str] = None,
-        use_v2: bool = True,
+        id_access_token: str,
     ) -> JsonDict:
         """Bind a 3PID to an identity server
 
@@ -174,8 +172,7 @@ class IdentityHandler:
             mxid: The MXID to bind the 3PID to
             id_server: The domain of the identity server to query
             id_access_token: The access token to authenticate to the identity
-                server with, if necessary. Required if use_v2 is true
-            use_v2: Whether to use v2 Identity Service API endpoints. Defaults to True
+                server with
 
         Raises:
             SynapseError: On any of the following conditions
@@ -187,24 +184,15 @@ class IdentityHandler:
         """
         logger.debug("Proxying threepid bind request for %s to %s", mxid, id_server)
 
-        # If an id_access_token is not supplied, force usage of v1
-        if id_access_token is None:
-            use_v2 = False
-
         if not valid_id_server_location(id_server):
             raise SynapseError(
                 400,
                 "id_server must be a valid hostname with optional port and path components",
             )
 
-        # Decide which API endpoint URLs to use
-        headers = {}
         bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid}
-        if use_v2:
-            bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,)
-            headers["Authorization"] = create_id_access_token_header(id_access_token)  # type: ignore
-        else:
-            bind_url = "https://%s/_matrix/identity/api/v1/3pid/bind" % (id_server,)
+        bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,)
+        headers = {"Authorization": create_id_access_token_header(id_access_token)}
 
         try:
             # Use the blacklisting http client as this call is only to identity servers
@@ -223,21 +211,14 @@ class IdentityHandler:
 
             return data
         except HttpResponseException as e:
-            if e.code != 404 or not use_v2:
-                logger.error("3PID bind failed with Matrix error: %r", e)
-                raise e.to_synapse_error()
+            logger.error("3PID bind failed with Matrix error: %r", e)
+            raise e.to_synapse_error()
         except RequestTimedOutError:
             raise SynapseError(500, "Timed out contacting identity server")
         except CodeMessageException as e:
             data = json_decoder.decode(e.msg)  # XXX WAT?
             return data
 
-        logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url)
-        res = await self.bind_threepid(
-            client_secret, sid, mxid, id_server, id_access_token, use_v2=False
-        )
-        return res
-
     async def try_unbind_threepid(self, mxid: str, threepid: dict) -> bool:
         """Attempt to remove a 3PID from an identity server, or if one is not provided, all
         identity servers we're aware the binding is present on
@@ -300,8 +281,8 @@ class IdentityHandler:
                 "id_server must be a valid hostname with optional port and path components",
             )
 
-        url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
-        url_bytes = b"/_matrix/identity/api/v1/3pid/unbind"
+        url = "https://%s/_matrix/identity/v2/3pid/unbind" % (id_server,)
+        url_bytes = b"/_matrix/identity/v2/3pid/unbind"
 
         content = {
             "mxid": mxid,
@@ -434,48 +415,6 @@ class IdentityHandler:
 
         return session_id
 
-    async def requestEmailToken(
-        self,
-        id_server: str,
-        email: str,
-        client_secret: str,
-        send_attempt: int,
-        next_link: Optional[str] = None,
-    ) -> JsonDict:
-        """
-        Request an external server send an email on our behalf for the purposes of threepid
-        validation.
-
-        Args:
-            id_server: The identity server to proxy to
-            email: The email to send the message to
-            client_secret: The unique client_secret sends by the user
-            send_attempt: Which attempt this is
-            next_link: A link to redirect the user to once they submit the token
-
-        Returns:
-            The json response body from the server
-        """
-        params = {
-            "email": email,
-            "client_secret": client_secret,
-            "send_attempt": send_attempt,
-        }
-        if next_link:
-            params["next_link"] = next_link
-
-        try:
-            data = await self.http_client.post_json_get_json(
-                id_server + "/_matrix/identity/api/v1/validate/email/requestToken",
-                params,
-            )
-            return data
-        except HttpResponseException as e:
-            logger.info("Proxied requestToken failed: %r", e)
-            raise e.to_synapse_error()
-        except RequestTimedOutError:
-            raise SynapseError(500, "Timed out contacting identity server")
-
     async def requestMsisdnToken(
         self,
         id_server: str,
@@ -549,18 +488,7 @@ class IdentityHandler:
         validation_session = None
 
         # Try to validate as email
-        if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
-            # Remote emails will only be used if a valid identity server is provided.
-            assert (
-                self.hs.config.registration.account_threepid_delegate_email is not None
-            )
-
-            # Ask our delegated email identity server
-            validation_session = await self.threepid_from_creds(
-                self.hs.config.registration.account_threepid_delegate_email,
-                threepid_creds,
-            )
-        elif self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+        if self.hs.config.email.can_verify_email:
             # Get a validated session matching these details
             validation_session = await self.store.get_threepid_validation_session(
                 "email", client_secret, sid=sid, validated=True
@@ -610,11 +538,7 @@ class IdentityHandler:
             raise SynapseError(400, "Error contacting the identity server")
 
     async def lookup_3pid(
-        self,
-        id_server: str,
-        medium: str,
-        address: str,
-        id_access_token: Optional[str] = None,
+        self, id_server: str, medium: str, address: str, id_access_token: str
     ) -> Optional[str]:
         """Looks up a 3pid in the passed identity server.
 
@@ -629,60 +553,15 @@ class IdentityHandler:
         Returns:
             the matrix ID of the 3pid, or None if it is not recognized.
         """
-        if id_access_token is not None:
-            try:
-                results = await self._lookup_3pid_v2(
-                    id_server, id_access_token, medium, address
-                )
-                return results
-
-            except Exception as e:
-                # Catch HttpResponseExcept for a non-200 response code
-                # Check if this identity server does not know about v2 lookups
-                if isinstance(e, HttpResponseException) and e.code == 404:
-                    # This is an old identity server that does not yet support v2 lookups
-                    logger.warning(
-                        "Attempted v2 lookup on v1 identity server %s. Falling "
-                        "back to v1",
-                        id_server,
-                    )
-                else:
-                    logger.warning("Error when looking up hashing details: %s", e)
-                    return None
-
-        return await self._lookup_3pid_v1(id_server, medium, address)
-
-    async def _lookup_3pid_v1(
-        self, id_server: str, medium: str, address: str
-    ) -> Optional[str]:
-        """Looks up a 3pid in the passed identity server using v1 lookup.
-
-        Args:
-            id_server: The server name (including port, if required)
-                of the identity server to use.
-            medium: The type of the third party identifier (e.g. "email").
-            address: The third party identifier (e.g. "foo@example.com").
 
-        Returns:
-            the matrix ID of the 3pid, or None if it is not recognized.
-        """
         try:
-            data = await self.blacklisting_http_client.get_json(
-                "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server),
-                {"medium": medium, "address": address},
+            results = await self._lookup_3pid_v2(
+                id_server, id_access_token, medium, address
             )
-
-            if "mxid" in data:
-                # note: we used to verify the identity server's signature here, but no longer
-                # require or validate it. See the following for context:
-                # https://github.com/matrix-org/synapse/issues/5253#issuecomment-666246950
-                return data["mxid"]
-        except RequestTimedOutError:
-            raise SynapseError(500, "Timed out contacting identity server")
-        except OSError as e:
-            logger.warning("Error from v1 identity server lookup: %s" % (e,))
-
-        return None
+            return results
+        except Exception as e:
+            logger.warning("Error when looking up hashing details: %s", e)
+            return None
 
     async def _lookup_3pid_v2(
         self, id_server: str, id_access_token: str, medium: str, address: str
@@ -811,7 +690,7 @@ class IdentityHandler:
         room_type: Optional[str],
         inviter_display_name: str,
         inviter_avatar_url: str,
-        id_access_token: Optional[str] = None,
+        id_access_token: str,
     ) -> Tuple[str, List[Dict[str, str]], Dict[str, str], str]:
         """
         Asks an identity server for a third party invite.
@@ -832,7 +711,7 @@ class IdentityHandler:
             inviter_display_name: The current display name of the
                 inviter.
             inviter_avatar_url: The URL of the inviter's avatar.
-            id_access_token (str|None): The access token to authenticate to the identity
+            id_access_token (str): The access token to authenticate to the identity
                 server with
 
         Returns:
@@ -864,71 +743,24 @@ class IdentityHandler:
             invite_config["org.matrix.web_client_location"] = self._web_client_location
 
         # Add the identity service access token to the JSON body and use the v2
-        # Identity Service endpoints if id_access_token is present
+        # Identity Service endpoints
         data = None
-        base_url = "%s%s/_matrix/identity" % (id_server_scheme, id_server)
 
-        if id_access_token:
-            key_validity_url = "%s%s/_matrix/identity/v2/pubkey/isvalid" % (
-                id_server_scheme,
-                id_server,
-            )
+        key_validity_url = "%s%s/_matrix/identity/v2/pubkey/isvalid" % (
+            id_server_scheme,
+            id_server,
+        )
 
-            # Attempt a v2 lookup
-            url = base_url + "/v2/store-invite"
-            try:
-                data = await self.blacklisting_http_client.post_json_get_json(
-                    url,
-                    invite_config,
-                    {"Authorization": create_id_access_token_header(id_access_token)},
-                )
-            except RequestTimedOutError:
-                raise SynapseError(500, "Timed out contacting identity server")
-            except HttpResponseException as e:
-                if e.code != 404:
-                    logger.info("Failed to POST %s with JSON: %s", url, e)
-                    raise e
-
-        if data is None:
-            key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
-                id_server_scheme,
-                id_server,
+        url = "%s%s/_matrix/identity/v2/store-invite" % (id_server_scheme, id_server)
+        try:
+            data = await self.blacklisting_http_client.post_json_get_json(
+                url,
+                invite_config,
+                {"Authorization": create_id_access_token_header(id_access_token)},
             )
-            url = base_url + "/api/v1/store-invite"
-
-            try:
-                data = await self.blacklisting_http_client.post_json_get_json(
-                    url, invite_config
-                )
-            except RequestTimedOutError:
-                raise SynapseError(500, "Timed out contacting identity server")
-            except HttpResponseException as e:
-                logger.warning(
-                    "Error trying to call /store-invite on %s%s: %s",
-                    id_server_scheme,
-                    id_server,
-                    e,
-                )
-
-            if data is None:
-                # Some identity servers may only support application/x-www-form-urlencoded
-                # types. This is especially true with old instances of Sydent, see
-                # https://github.com/matrix-org/sydent/pull/170
-                try:
-                    data = await self.blacklisting_http_client.post_urlencoded_get_json(
-                        url, invite_config
-                    )
-                except HttpResponseException as e:
-                    logger.warning(
-                        "Error calling /store-invite on %s%s with fallback "
-                        "encoding: %s",
-                        id_server_scheme,
-                        id_server,
-                        e,
-                    )
-                    raise e
-
-        # TODO: Check for success
+        except RequestTimedOutError:
+            raise SynapseError(500, "Timed out contacting identity server")
+
         token = data["token"]
         public_keys = data.get("public_keys", [])
         if "public_key" in data:
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 85b472f250..860c82c110 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -143,8 +143,8 @@ class InitialSyncHandler:
             joined_rooms,
             to_key=int(now_token.receipt_key),
         )
-        if self.hs.config.experimental.msc2285_enabled:
-            receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id)
+
+        receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id)
 
         tags_by_room = await self.store.get_tags_for_user(user_id)
 
@@ -309,18 +309,18 @@ class InitialSyncHandler:
         if blocked:
             raise SynapseError(403, "This room has been blocked on this server")
 
-        user_id = requester.user.to_string()
-
         (
             membership,
             member_event_id,
         ) = await self.auth.check_user_in_room_or_world_readable(
             room_id,
-            user_id,
+            requester,
             allow_departed_users=True,
         )
         is_peeking = member_event_id is None
 
+        user_id = requester.user.to_string()
+
         if membership == Membership.JOIN:
             result = await self._room_initial_sync_joined(
                 user_id, room_id, pagin_config, membership, is_peeking
@@ -456,11 +456,8 @@ class InitialSyncHandler:
             )
             if not receipts:
                 return []
-            if self.hs.config.experimental.msc2285_enabled:
-                receipts = ReceiptEventSource.filter_out_private_receipts(
-                    receipts, user_id
-                )
-            return receipts
+
+            return ReceiptEventSource.filter_out_private_receipts(receipts, user_id)
 
         presence, receipts, (messages, token) = await make_deferred_yieldable(
             gather_results(
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 1980e37dae..e07cda133a 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -41,6 +41,7 @@ from synapse.api.errors import (
     NotFoundError,
     ShadowBanError,
     SynapseError,
+    UnstableSpecAuthError,
     UnsupportedRoomVersionError,
 )
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
@@ -51,6 +52,7 @@ from synapse.events.builder import EventBuilder
 from synapse.events.snapshot import EventContext
 from synapse.events.validator import EventValidator
 from synapse.handlers.directory import DirectoryHandler
+from synapse.logging import opentracing
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.http.send_event import ReplicationSendEventRestServlet
@@ -102,7 +104,7 @@ class MessageHandler:
 
     async def get_room_data(
         self,
-        user_id: str,
+        requester: Requester,
         room_id: str,
         event_type: str,
         state_key: str,
@@ -110,7 +112,7 @@ class MessageHandler:
         """Get data from a room.
 
         Args:
-            user_id
+            requester: The user who did the request.
             room_id
             event_type
             state_key
@@ -123,7 +125,7 @@ class MessageHandler:
             membership,
             membership_event_id,
         ) = await self.auth.check_user_in_room_or_world_readable(
-            room_id, user_id, allow_departed_users=True
+            room_id, requester, allow_departed_users=True
         )
 
         if membership == Membership.JOIN:
@@ -149,17 +151,20 @@ class MessageHandler:
                 "Attempted to retrieve data from a room for a user that has never been in it. "
                 "This should not have happened."
             )
-            raise SynapseError(403, "User not in room", errcode=Codes.FORBIDDEN)
+            raise UnstableSpecAuthError(
+                403,
+                "User not in room",
+                errcode=Codes.NOT_JOINED,
+            )
 
         return data
 
     async def get_state_events(
         self,
-        user_id: str,
+        requester: Requester,
         room_id: str,
         state_filter: Optional[StateFilter] = None,
         at_token: Optional[StreamToken] = None,
-        is_guest: bool = False,
     ) -> List[dict]:
         """Retrieve all state events for a given room. If the user is
         joined to the room then return the current state. If the user has
@@ -168,14 +173,13 @@ class MessageHandler:
         visible.
 
         Args:
-            user_id: The user requesting state events.
+            requester: The user requesting state events.
             room_id: The room ID to get all state events from.
             state_filter: The state filter used to fetch state from the database.
             at_token: the stream token of the at which we are requesting
                 the stats. If the user is not allowed to view the state as of that
                 stream token, we raise a 403 SynapseError. If None, returns the current
                 state based on the current_state_events table.
-            is_guest: whether this user is a guest
         Returns:
             A list of dicts representing state events. [{}, {}, {}]
         Raises:
@@ -185,6 +189,7 @@ class MessageHandler:
             members of this room.
         """
         state_filter = state_filter or StateFilter.all()
+        user_id = requester.user.to_string()
 
         if at_token:
             last_event_id = (
@@ -217,7 +222,7 @@ class MessageHandler:
                 membership,
                 membership_event_id,
             ) = await self.auth.check_user_in_room_or_world_readable(
-                room_id, user_id, allow_departed_users=True
+                room_id, requester, allow_departed_users=True
             )
 
             if membership == Membership.JOIN:
@@ -311,30 +316,42 @@ class MessageHandler:
         Returns:
             A dict of user_id to profile info
         """
-        user_id = requester.user.to_string()
         if not requester.app_service:
             # We check AS auth after fetching the room membership, as it
             # requires us to pull out all joined members anyway.
             membership, _ = await self.auth.check_user_in_room_or_world_readable(
-                room_id, user_id, allow_departed_users=True
+                room_id, requester, allow_departed_users=True
             )
             if membership != Membership.JOIN:
-                raise NotImplementedError(
-                    "Getting joined members after leaving is not implemented"
+                raise SynapseError(
+                    code=403,
+                    errcode=Codes.FORBIDDEN,
+                    msg="Getting joined members while not being a current member of the room is forbidden.",
                 )
 
-        users_with_profile = await self.store.get_users_in_room_with_profiles(room_id)
+        users_with_profile = (
+            await self._state_storage_controller.get_users_in_room_with_profiles(
+                room_id
+            )
+        )
 
         # If this is an AS, double check that they are allowed to see the members.
         # This can either be because the AS user is in the room or because there
         # is a user in the room that the AS is "interested in"
-        if requester.app_service and user_id not in users_with_profile:
+        if (
+            requester.app_service
+            and requester.user.to_string() not in users_with_profile
+        ):
             for uid in users_with_profile:
                 if requester.app_service.is_interested_in_user(uid):
                     break
             else:
                 # Loop fell through, AS has no interested users in room
-                raise AuthError(403, "Appservice not in room")
+                raise UnstableSpecAuthError(
+                    403,
+                    "Appservice not in room",
+                    errcode=Codes.NOT_JOINED,
+                )
 
         return {
             user_id: {
@@ -463,6 +480,7 @@ class EventCreationHandler:
         )
         self._events_shard_config = self.config.worker.events_shard_config
         self._instance_name = hs.get_instance_name()
+        self._notifier = hs.get_notifier()
 
         self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state
 
@@ -734,18 +752,12 @@ class EventCreationHandler:
         if builder.type == EventTypes.Member:
             membership = builder.content.get("membership", None)
             if membership == Membership.JOIN:
-                return await self._is_server_notices_room(builder.room_id)
+                return await self.store.is_server_notice_room(builder.room_id)
             elif membership == Membership.LEAVE:
                 # the user is always allowed to leave (but not kick people)
                 return builder.state_key == requester.user.to_string()
         return False
 
-    async def _is_server_notices_room(self, room_id: str) -> bool:
-        if self.config.servernotices.server_notices_mxid is None:
-            return False
-        user_ids = await self.store.get_users_in_room(room_id)
-        return self.config.servernotices.server_notices_mxid in user_ids
-
     async def assert_accepted_privacy_policy(self, requester: Requester) -> None:
         """Check if a user has accepted the privacy policy
 
@@ -1134,6 +1146,10 @@ class EventCreationHandler:
             context = await self.state.compute_event_context(
                 event,
                 state_ids_before_event=state_map_for_event,
+                # TODO(faster_joins): check how MSC2716 works and whether we can have
+                #   partial state here
+                #   https://github.com/matrix-org/synapse/issues/13003
+                partial_state=False,
             )
         else:
             context = await self.state.compute_event_context(event)
@@ -1358,9 +1374,10 @@ class EventCreationHandler:
         # and `state_groups` because they have `prev_events` that aren't persisted yet
         # (historical messages persisted in reverse-chronological order).
         if not event.internal_metadata.is_historical():
-            await self._bulk_push_rule_evaluator.action_for_event_by_user(
-                event, context
-            )
+            with opentracing.start_active_span("calculate_push_actions"):
+                await self._bulk_push_rule_evaluator.action_for_event_by_user(
+                    event, context
+                )
 
         try:
             # If we're a worker we need to hit out to the master.
@@ -1444,7 +1461,13 @@ class EventCreationHandler:
             if state_entry.state_group in self._external_cache_joined_hosts_updates:
                 return
 
-            joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry)
+            state = await state_entry.get_state(
+                self._storage_controllers.state, StateFilter.all()
+            )
+            with opentracing.start_active_span("get_joined_hosts"):
+                joined_hosts = await self.store.get_joined_hosts(
+                    event.room_id, state, state_entry
+                )
 
             # Note that the expiry times must be larger than the expiry time in
             # _external_cache_joined_hosts_updates.
@@ -1545,6 +1568,16 @@ class EventCreationHandler:
                 requester, is_admin_redaction=is_admin_redaction
             )
 
+        if event.type == EventTypes.Member and event.membership == Membership.JOIN:
+            (
+                current_membership,
+                _,
+            ) = await self.store.get_local_current_membership_for_user_in_room(
+                event.state_key, event.room_id
+            )
+            if current_membership != Membership.JOIN:
+                self._notifier.notify_user_joined_room(event.event_id, event.room_id)
+
         await self._maybe_kick_guest_users(event, context)
 
         if event.type == EventTypes.CanonicalAlias:
@@ -1844,13 +1877,8 @@ class EventCreationHandler:
 
         # For each room we need to find a joined member we can use to send
         # the dummy event with.
-        latest_event_ids = await self.store.get_prev_events_for_room(room_id)
-        members = await self.state.get_current_users_in_room(
-            room_id, latest_event_ids=latest_event_ids
-        )
+        members = await self.store.get_local_users_in_room(room_id)
         for user_id in members:
-            if not self.hs.is_mine_id(user_id):
-                continue
             requester = create_requester(user_id, authenticated_entity=self.server_name)
             try:
                 event, context = await self.create_event(
@@ -1861,7 +1889,6 @@ class EventCreationHandler:
                         "room_id": room_id,
                         "sender": user_id,
                     },
-                    prev_event_ids=latest_event_ids,
                 )
 
                 event.internal_metadata.proactively_send = False
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 6262a35822..1f83bab836 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -24,7 +24,9 @@ from synapse.api.errors import SynapseError
 from synapse.api.filtering import Filter
 from synapse.events.utils import SerializeEventConfig
 from synapse.handlers.room import ShutdownRoomResponse
+from synapse.logging.opentracing import trace
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.rest.admin._base import assert_user_is_admin
 from synapse.storage.state import StateFilter
 from synapse.streams.config import PaginationConfig
 from synapse.types import JsonDict, Requester, StreamKeyType
@@ -158,11 +160,9 @@ class PaginationHandler:
         self._retention_allowed_lifetime_max = (
             hs.config.retention.retention_allowed_lifetime_max
         )
+        self._is_master = hs.config.worker.worker_app is None
 
-        if (
-            hs.config.worker.run_background_tasks
-            and hs.config.retention.retention_enabled
-        ):
+        if hs.config.retention.retention_enabled and self._is_master:
             # Run the purge jobs described in the configuration file.
             for job in hs.config.retention.retention_purge_jobs:
                 logger.info("Setting up purge job with config: %s", job)
@@ -416,6 +416,7 @@ class PaginationHandler:
 
             await self._storage_controllers.purge_events.purge_room(room_id)
 
+    @trace
     async def get_messages(
         self,
         requester: Requester,
@@ -423,6 +424,7 @@ class PaginationHandler:
         pagin_config: PaginationConfig,
         as_client_event: bool = True,
         event_filter: Optional[Filter] = None,
+        use_admin_priviledge: bool = False,
     ) -> JsonDict:
         """Get messages in a room.
 
@@ -432,10 +434,16 @@ class PaginationHandler:
             pagin_config: The pagination config rules to apply, if any.
             as_client_event: True to get events in client-server format.
             event_filter: Filter to apply to results or None
+            use_admin_priviledge: if `True`, return all events, regardless
+                of whether `user` has access to them. To be used **ONLY**
+                from the admin API.
 
         Returns:
             Pagination API results
         """
+        if use_admin_priviledge:
+            await assert_user_is_admin(self.auth, requester)
+
         user_id = requester.user.to_string()
 
         if pagin_config.from_token:
@@ -458,12 +466,14 @@ class PaginationHandler:
         room_token = from_token.room_key
 
         async with self.pagination_lock.read(room_id):
-            (
-                membership,
-                member_event_id,
-            ) = await self.auth.check_user_in_room_or_world_readable(
-                room_id, user_id, allow_departed_users=True
-            )
+            (membership, member_event_id) = (None, None)
+            if not use_admin_priviledge:
+                (
+                    membership,
+                    member_event_id,
+                ) = await self.auth.check_user_in_room_or_world_readable(
+                    room_id, requester, allow_departed_users=True
+                )
 
             if pagin_config.direction == "b":
                 # if we're going backwards, we might need to backfill. This
@@ -475,7 +485,7 @@ class PaginationHandler:
                         room_id, room_token.stream
                     )
 
-                if membership == Membership.LEAVE:
+                if not use_admin_priviledge and membership == Membership.LEAVE:
                     # If they have left the room then clamp the token to be before
                     # they left the room, to save the effort of loading from the
                     # database.
@@ -528,12 +538,13 @@ class PaginationHandler:
         if event_filter:
             events = await event_filter.filter(events)
 
-        events = await filter_events_for_client(
-            self._storage_controllers,
-            user_id,
-            events,
-            is_peeking=(member_event_id is None),
-        )
+        if not use_admin_priviledge:
+            events = await filter_events_for_client(
+                self._storage_controllers,
+                user_id,
+                events,
+                is_peeking=(member_event_id is None),
+            )
 
         # if after the filter applied there are no more events
         # return immediately - but there might be more in next_token batch
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 895ea63ed3..4e575ffbaa 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -34,7 +34,6 @@ from typing import (
     Callable,
     Collection,
     Dict,
-    FrozenSet,
     Generator,
     Iterable,
     List,
@@ -42,7 +41,6 @@ from typing import (
     Set,
     Tuple,
     Type,
-    Union,
 )
 
 from prometheus_client import Counter
@@ -68,7 +66,6 @@ from synapse.storage.databases.main import DataStore
 from synapse.streams import EventSource
 from synapse.types import JsonDict, StreamKeyType, UserID, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
-from synapse.util.caches.descriptors import _CacheContext, cached
 from synapse.util.metrics import Measure
 from synapse.util.wheel_timer import WheelTimer
 
@@ -1656,15 +1653,18 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
                 # doesn't return. C.f. #5503.
                 return [], max_token
 
-            # Figure out which other users this user should receive updates for
-            users_interested_in = await self._get_interested_in(user, explicit_room_id)
+            # Figure out which other users this user should explicitly receive
+            # updates for
+            additional_users_interested_in = (
+                await self.get_presence_router().get_interested_users(user.to_string())
+            )
 
             # We have a set of users that we're interested in the presence of. We want to
             # cross-reference that with the users that have actually changed their presence.
 
             # Check whether this user should see all user updates
 
-            if users_interested_in == PresenceRouter.ALL_USERS:
+            if additional_users_interested_in == PresenceRouter.ALL_USERS:
                 # Provide presence state for all users
                 presence_updates = await self._filter_all_presence_updates_for_user(
                     user_id, include_offline, from_key
@@ -1673,34 +1673,47 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
                 return presence_updates, max_token
 
             # Make mypy happy. users_interested_in should now be a set
-            assert not isinstance(users_interested_in, str)
+            assert not isinstance(additional_users_interested_in, str)
+
+            # We always care about our own presence.
+            additional_users_interested_in.add(user_id)
+
+            if explicit_room_id:
+                user_ids = await self.store.get_users_in_room(explicit_room_id)
+                additional_users_interested_in.update(user_ids)
 
             # The set of users that we're interested in and that have had a presence update.
             # We'll actually pull the presence updates for these users at the end.
-            interested_and_updated_users: Union[Set[str], FrozenSet[str]] = set()
+            interested_and_updated_users: Collection[str]
 
             if from_key is not None:
                 # First get all users that have had a presence update
                 updated_users = stream_change_cache.get_all_entities_changed(from_key)
 
                 # Cross-reference users we're interested in with those that have had updates.
-                # Use a slightly-optimised method for processing smaller sets of updates.
-                if updated_users is not None and len(updated_users) < 500:
-                    # For small deltas, it's quicker to get all changes and then
-                    # cross-reference with the users we're interested in
+                if updated_users is not None:
+                    # If we have the full list of changes for presence we can
+                    # simply check which ones share a room with the user.
                     get_updates_counter.labels("stream").inc()
-                    for other_user_id in updated_users:
-                        if other_user_id in users_interested_in:
-                            # mypy thinks this variable could be a FrozenSet as it's possibly set
-                            # to one in the `get_entities_changed` call below, and `add()` is not
-                            # method on a FrozenSet. That doesn't affect us here though, as
-                            # `interested_and_updated_users` is clearly a set() above.
-                            interested_and_updated_users.add(other_user_id)  # type: ignore
+
+                    sharing_users = await self.store.do_users_share_a_room(
+                        user_id, updated_users
+                    )
+
+                    interested_and_updated_users = (
+                        sharing_users.union(additional_users_interested_in)
+                    ).intersection(updated_users)
+
                 else:
                     # Too many possible updates. Find all users we can see and check
                     # if any of them have changed.
                     get_updates_counter.labels("full").inc()
 
+                    users_interested_in = (
+                        await self.store.get_users_who_share_room_with_user(user_id)
+                    )
+                    users_interested_in.update(additional_users_interested_in)
+
                     interested_and_updated_users = (
                         stream_change_cache.get_entities_changed(
                             users_interested_in, from_key
@@ -1709,7 +1722,10 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
             else:
                 # No from_key has been specified. Return the presence for all users
                 # this user is interested in
-                interested_and_updated_users = users_interested_in
+                interested_and_updated_users = (
+                    await self.store.get_users_who_share_room_with_user(user_id)
+                )
+                interested_and_updated_users.update(additional_users_interested_in)
 
             # Retrieve the current presence state for each user
             users_to_state = await self.get_presence_handler().current_state_for_users(
@@ -1804,62 +1820,6 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
     def get_current_key(self) -> int:
         return self.store.get_current_presence_token()
 
-    @cached(num_args=2, cache_context=True)
-    async def _get_interested_in(
-        self,
-        user: UserID,
-        explicit_room_id: Optional[str] = None,
-        cache_context: Optional[_CacheContext] = None,
-    ) -> Union[Set[str], str]:
-        """Returns the set of users that the given user should see presence
-        updates for.
-
-        Args:
-            user: The user to retrieve presence updates for.
-            explicit_room_id: The users that are in the room will be returned.
-
-        Returns:
-            A set of user IDs to return presence updates for, or "ALL" to return all
-            known updates.
-        """
-        user_id = user.to_string()
-        users_interested_in = set()
-        users_interested_in.add(user_id)  # So that we receive our own presence
-
-        # cache_context isn't likely to ever be None due to the @cached decorator,
-        # but we can't have a non-optional argument after the optional argument
-        # explicit_room_id either. Assert cache_context is not None so we can use it
-        # without mypy complaining.
-        assert cache_context
-
-        # Check with the presence router whether we should poll additional users for
-        # their presence information
-        additional_users = await self.get_presence_router().get_interested_users(
-            user.to_string()
-        )
-        if additional_users == PresenceRouter.ALL_USERS:
-            # If the module requested that this user see the presence updates of *all*
-            # users, then simply return that instead of calculating what rooms this
-            # user shares
-            return PresenceRouter.ALL_USERS
-
-        # Add the additional users from the router
-        users_interested_in.update(additional_users)
-
-        # Find the users who share a room with this user
-        users_who_share_room = await self.store.get_users_who_share_room_with_user(
-            user_id, on_invalidate=cache_context.invalidate
-        )
-        users_interested_in.update(users_who_share_room)
-
-        if explicit_room_id:
-            user_ids = await self.store.get_users_in_room(
-                explicit_room_id, on_invalidate=cache_context.invalidate
-            )
-            users_interested_in.update(user_ids)
-
-        return users_interested_in
-
 
 def handle_timeouts(
     user_states: List[UserPresenceState],
@@ -2091,8 +2051,7 @@ async def get_interested_remotes(
     )
 
     for room_id, states in room_ids_to_states.items():
-        user_ids = await store.get_users_in_room(room_id)
-        hosts = {get_domain_from_id(user_id) for user_id in user_ids}
+        hosts = await store.get_current_hosts_in_room(room_id)
         for host in hosts:
             hosts_and_states.setdefault(host, set()).update(states)
 
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 43d2882b0a..d2bdb9c8be 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -256,10 +256,9 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
             room_ids, from_key=from_key, to_key=to_key
         )
 
-        if self.config.experimental.msc2285_enabled:
-            events = ReceiptEventSource.filter_out_private_receipts(
-                events, user.to_string()
-            )
+        events = ReceiptEventSource.filter_out_private_receipts(
+            events, user.to_string()
+        )
 
         return events, to_key
 
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index c77d181722..20ec22105a 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -29,7 +29,13 @@ from synapse.api.constants import (
     JoinRules,
     LoginType,
 )
-from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
+from synapse.api.errors import (
+    AuthError,
+    Codes,
+    ConsentNotGivenError,
+    InvalidClientTokenError,
+    SynapseError,
+)
 from synapse.appservice import ApplicationService
 from synapse.config.server import is_threepid_reserved
 from synapse.http.servlet import assert_params_in_dict
@@ -180,10 +186,7 @@ class RegistrationHandler:
                 )
             if guest_access_token:
                 user_data = await self.auth.get_user_by_access_token(guest_access_token)
-                if (
-                    not user_data.is_guest
-                    or UserID.from_string(user_data.user_id).localpart != localpart
-                ):
+                if not user_data.is_guest or user_data.user.localpart != localpart:
                     raise AuthError(
                         403,
                         "Cannot register taken user ID without valid guest "
@@ -618,7 +621,7 @@ class RegistrationHandler:
         user_id = user.to_string()
         service = self.store.get_app_service_by_token(as_token)
         if not service:
-            raise AuthError(403, "Invalid application service token.")
+            raise InvalidClientTokenError()
         if not service.is_interested_in_user(user_id):
             raise SynapseError(
                 400,
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 0b63cd2186..28d7093f08 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -19,6 +19,7 @@ import attr
 from synapse.api.constants import RelationTypes
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase, relation_from_event
+from synapse.logging.opentracing import trace
 from synapse.storage.databases.main.relations import _RelatedEvent
 from synapse.types import JsonDict, Requester, StreamToken, UserID
 from synapse.visibility import filter_events_for_client
@@ -73,7 +74,6 @@ class RelationsHandler:
         room_id: str,
         relation_type: Optional[str] = None,
         event_type: Optional[str] = None,
-        aggregation_key: Optional[str] = None,
         limit: int = 5,
         direction: str = "b",
         from_token: Optional[StreamToken] = None,
@@ -89,7 +89,6 @@ class RelationsHandler:
             room_id: The room the event belongs to.
             relation_type: Only fetch events with this relation type, if given.
             event_type: Only fetch events with this event type, if given.
-            aggregation_key: Only fetch events with this aggregation key, if given.
             limit: Only fetch the most recent `limit` events.
             direction: Whether to fetch the most recent first (`"b"`) or the
                 oldest first (`"f"`).
@@ -104,7 +103,7 @@ class RelationsHandler:
 
         # TODO Properly handle a user leaving a room.
         (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable(
-            room_id, user_id, allow_departed_users=True
+            room_id, requester, allow_departed_users=True
         )
 
         # This gets the original event and checks that a) the event exists and
@@ -122,7 +121,6 @@ class RelationsHandler:
             room_id=room_id,
             relation_type=relation_type,
             event_type=event_type,
-            aggregation_key=aggregation_key,
             limit=limit,
             direction=direction,
             from_token=from_token,
@@ -364,6 +362,7 @@ class RelationsHandler:
 
         return results
 
+    @trace
     async def get_bundled_aggregations(
         self, events: Iterable[EventBase], user_id: str
     ) -> Dict[str, BundledAggregations]:
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index a54f163c0a..33e9a87002 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -19,6 +19,7 @@ import math
 import random
 import string
 from collections import OrderedDict
+from http import HTTPStatus
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -60,7 +61,6 @@ from synapse.event_auth import validate_event_for_room_version
 from synapse.events import EventBase
 from synapse.events.utils import copy_and_fixup_power_levels_contents
 from synapse.federation.federation_client import InvalidResponseError
-from synapse.handlers.federation import get_domains_from_state
 from synapse.handlers.relations import BundledAggregations
 from synapse.module_api import NOT_SPAM
 from synapse.rest.admin._base import assert_user_is_admin
@@ -705,8 +705,8 @@ class RoomCreationHandler:
                 was, requested, `room_alias`. Secondly, the stream_id of the
                 last persisted event.
         Raises:
-            SynapseError if the room ID couldn't be stored, or something went
-            horribly wrong.
+            SynapseError if the room ID couldn't be stored, 3pid invitation config
+            validation failed, or something went horribly wrong.
             ResourceLimitError if server is blocked to some resource being
             exceeded
         """
@@ -721,7 +721,7 @@ class RoomCreationHandler:
             # allow the server notices mxid to create rooms
             is_requester_admin = True
         else:
-            is_requester_admin = await self.auth.is_server_admin(requester.user)
+            is_requester_admin = await self.auth.is_server_admin(requester)
 
         # Let the third party rules modify the room creation config if needed, or abort
         # the room creation entirely with an exception.
@@ -732,6 +732,19 @@ class RoomCreationHandler:
         invite_3pid_list = config.get("invite_3pid", [])
         invite_list = config.get("invite", [])
 
+        # validate each entry for correctness
+        for invite_3pid in invite_3pid_list:
+            if not all(
+                key in invite_3pid
+                for key in ("medium", "address", "id_server", "id_access_token")
+            ):
+                raise SynapseError(
+                    HTTPStatus.BAD_REQUEST,
+                    "all of `medium`, `address`, `id_server` and `id_access_token` "
+                    "are required when making a 3pid invite",
+                    Codes.MISSING_PARAM,
+                )
+
         if not is_requester_admin:
             spam_check = await self.spam_checker.user_may_create_room(user_id)
             if spam_check != NOT_SPAM:
@@ -889,7 +902,11 @@ class RoomCreationHandler:
         # override any attempt to set room versions via the creation_content
         creation_content["room_version"] = room_version.identifier
 
-        last_stream_id = await self._send_events_for_new_room(
+        (
+            last_stream_id,
+            last_sent_event_id,
+            depth,
+        ) = await self._send_events_for_new_room(
             requester,
             room_id,
             preset_config=preset_config,
@@ -905,7 +922,7 @@ class RoomCreationHandler:
         if "name" in config:
             name = config["name"]
             (
-                _,
+                name_event,
                 last_stream_id,
             ) = await self.event_creation_handler.create_and_send_nonmember_event(
                 requester,
@@ -917,12 +934,16 @@ class RoomCreationHandler:
                     "content": {"name": name},
                 },
                 ratelimit=False,
+                prev_event_ids=[last_sent_event_id],
+                depth=depth,
             )
+            last_sent_event_id = name_event.event_id
+            depth += 1
 
         if "topic" in config:
             topic = config["topic"]
             (
-                _,
+                topic_event,
                 last_stream_id,
             ) = await self.event_creation_handler.create_and_send_nonmember_event(
                 requester,
@@ -934,7 +955,11 @@ class RoomCreationHandler:
                     "content": {"topic": topic},
                 },
                 ratelimit=False,
+                prev_event_ids=[last_sent_event_id],
+                depth=depth,
             )
+            last_sent_event_id = topic_event.event_id
+            depth += 1
 
         # we avoid dropping the lock between invites, as otherwise joins can
         # start coming in and making the createRoom slow.
@@ -949,7 +974,7 @@ class RoomCreationHandler:
 
             for invitee in invite_list:
                 (
-                    _,
+                    member_event_id,
                     last_stream_id,
                 ) = await self.room_member_handler.update_membership_locked(
                     requester,
@@ -959,16 +984,23 @@ class RoomCreationHandler:
                     ratelimit=False,
                     content=content,
                     new_room=True,
+                    prev_event_ids=[last_sent_event_id],
+                    depth=depth,
                 )
+                last_sent_event_id = member_event_id
+                depth += 1
 
         for invite_3pid in invite_3pid_list:
             id_server = invite_3pid["id_server"]
-            id_access_token = invite_3pid.get("id_access_token")  # optional
+            id_access_token = invite_3pid["id_access_token"]
             address = invite_3pid["address"]
             medium = invite_3pid["medium"]
             # Note that do_3pid_invite can raise a  ShadowBanError, but this was
             # handled above by emptying invite_3pid_list.
-            last_stream_id = await self.hs.get_room_member_handler().do_3pid_invite(
+            (
+                member_event_id,
+                last_stream_id,
+            ) = await self.hs.get_room_member_handler().do_3pid_invite(
                 room_id,
                 requester.user,
                 medium,
@@ -977,7 +1009,11 @@ class RoomCreationHandler:
                 requester,
                 txn_id=None,
                 id_access_token=id_access_token,
+                prev_event_ids=[last_sent_event_id],
+                depth=depth,
             )
+            last_sent_event_id = member_event_id
+            depth += 1
 
         result = {"room_id": room_id}
 
@@ -1005,20 +1041,22 @@ class RoomCreationHandler:
         power_level_content_override: Optional[JsonDict] = None,
         creator_join_profile: Optional[JsonDict] = None,
         ratelimit: bool = True,
-    ) -> int:
+    ) -> Tuple[int, str, int]:
         """Sends the initial events into a new room.
 
         `power_level_content_override` doesn't apply when initial state has
         power level state event content.
 
         Returns:
-            The stream_id of the last event persisted.
+            A tuple containing the stream ID, event ID and depth of the last
+            event sent to the room.
         """
 
         creator_id = creator.user.to_string()
 
         event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
 
+        depth = 1
         last_sent_event_id: Optional[str] = None
 
         def create(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict:
@@ -1031,6 +1069,7 @@ class RoomCreationHandler:
 
         async def send(etype: str, content: JsonDict, **kwargs: Any) -> int:
             nonlocal last_sent_event_id
+            nonlocal depth
 
             event = create(etype, content, **kwargs)
             logger.debug("Sending %s in new room", etype)
@@ -1047,9 +1086,11 @@ class RoomCreationHandler:
                 # Note: we don't pass state_event_ids here because this triggers
                 # an additional query per event to look them up from the events table.
                 prev_event_ids=[last_sent_event_id] if last_sent_event_id else [],
+                depth=depth,
             )
 
             last_sent_event_id = sent_event.event_id
+            depth += 1
 
             return last_stream_id
 
@@ -1075,6 +1116,7 @@ class RoomCreationHandler:
             content=creator_join_profile,
             new_room=True,
             prev_event_ids=[last_sent_event_id],
+            depth=depth,
         )
         last_sent_event_id = member_event_id
 
@@ -1168,7 +1210,7 @@ class RoomCreationHandler:
                 content={"algorithm": RoomEncryptionAlgorithms.DEFAULT},
             )
 
-        return last_sent_stream_id
+        return last_sent_stream_id, last_sent_event_id, depth
 
     def _generate_room_id(self) -> str:
         """Generates a random room ID.
@@ -1250,13 +1292,16 @@ class RoomContextHandler:
         """
         user = requester.user
         if use_admin_priviledge:
-            await assert_user_is_admin(self.auth, requester.user)
+            await assert_user_is_admin(self.auth, requester)
 
         before_limit = math.floor(limit / 2.0)
         after_limit = limit - before_limit
 
-        users = await self.store.get_users_in_room(room_id)
-        is_peeking = user.to_string() not in users
+        is_user_in_room = await self.store.check_local_user_in_room(
+            user_id=user.to_string(), room_id=room_id
+        )
+        # The user is peeking if they aren't in the room already
+        is_peeking = not is_user_in_room
 
         async def filter_evts(events: List[EventBase]) -> List[EventBase]:
             if use_admin_priviledge:
@@ -1355,6 +1400,7 @@ class TimestampLookupHandler:
         self.store = hs.get_datastores().main
         self.state_handler = hs.get_state_handler()
         self.federation_client = hs.get_federation_client()
+        self.federation_event_handler = hs.get_federation_event_handler()
         self._storage_controllers = hs.get_storage_controllers()
 
     async def get_event_for_timestamp(
@@ -1429,17 +1475,16 @@ class TimestampLookupHandler:
                 timestamp,
             )
 
-            # Find other homeservers from the given state in the room
-            curr_state = await self._storage_controllers.state.get_current_state(
-                room_id
+            likely_domains = (
+                await self._storage_controllers.state.get_current_hosts_in_room(room_id)
             )
-            curr_domains = get_domains_from_state(curr_state)
-            likely_domains = [
-                domain for domain, depth in curr_domains if domain != self.server_name
-            ]
 
             # Loop through each homeserver candidate until we get a succesful response
             for domain in likely_domains:
+                # We don't want to ask our own server for information we don't have
+                if domain == self.server_name:
+                    continue
+
                 try:
                     remote_response = await self.federation_client.timestamp_to_event(
                         domain, room_id, timestamp, direction
@@ -1450,38 +1495,68 @@ class TimestampLookupHandler:
                         remote_response,
                     )
 
-                    # TODO: Do we want to persist this as an extremity?
-                    # TODO: I think ideally, we would try to backfill from
-                    # this event and run this whole
-                    # `get_event_for_timestamp` function again to make sure
-                    # they didn't give us an event from their gappy history.
                     remote_event_id = remote_response.event_id
-                    origin_server_ts = remote_response.origin_server_ts
+                    remote_origin_server_ts = remote_response.origin_server_ts
+
+                    # Backfill this event so we can get a pagination token for
+                    # it with `/context` and paginate `/messages` from this
+                    # point.
+                    #
+                    # TODO: The requested timestamp may lie in a part of the
+                    #   event graph that the remote server *also* didn't have,
+                    #   in which case they will have returned another event
+                    #   which may be nowhere near the requested timestamp. In
+                    #   the future, we may need to reconcile that gap and ask
+                    #   other homeservers, and/or extend `/timestamp_to_event`
+                    #   to return events on *both* sides of the timestamp to
+                    #   help reconcile the gap faster.
+                    remote_event = (
+                        await self.federation_event_handler.backfill_event_id(
+                            domain, room_id, remote_event_id
+                        )
+                    )
+
+                    # XXX: When we see that the remote server is not trustworthy,
+                    # maybe we should not ask them first in the future.
+                    if remote_origin_server_ts != remote_event.origin_server_ts:
+                        logger.info(
+                            "get_event_for_timestamp: Remote server (%s) claimed that remote_event_id=%s occured at remote_origin_server_ts=%s but that isn't true (actually occured at %s). Their claims are dubious and we should consider not trusting them.",
+                            domain,
+                            remote_event_id,
+                            remote_origin_server_ts,
+                            remote_event.origin_server_ts,
+                        )
 
                     # Only return the remote event if it's closer than the local event
                     if not local_event or (
-                        abs(origin_server_ts - timestamp)
+                        abs(remote_event.origin_server_ts - timestamp)
                         < abs(local_event.origin_server_ts - timestamp)
                     ):
-                        return remote_event_id, origin_server_ts
+                        logger.info(
+                            "get_event_for_timestamp: returning remote_event_id=%s (%s) since it's closer to timestamp=%s than local_event=%s (%s)",
+                            remote_event_id,
+                            remote_event.origin_server_ts,
+                            timestamp,
+                            local_event.event_id if local_event else None,
+                            local_event.origin_server_ts if local_event else None,
+                        )
+                        return remote_event_id, remote_origin_server_ts
                 except (HttpResponseException, InvalidResponseError) as ex:
                     # Let's not put a high priority on some other homeserver
                     # failing to respond or giving a random response
                     logger.debug(
-                        "Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s",
+                        "get_event_for_timestamp: Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s",
                         domain,
                         type(ex).__name__,
                         ex,
                         ex.args,
                     )
-                except Exception as ex:
+                except Exception:
                     # But we do want to see some exceptions in our code
                     logger.warning(
-                        "Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s",
+                        "get_event_for_timestamp: Failed to fetch /timestamp_to_event from %s because of exception",
                         domain,
-                        type(ex).__name__,
-                        ex,
-                        ex.args,
+                        exc_info=True,
                     )
 
         # To appease mypy, we have to add both of these conditions to check for
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 29868eb743..bb0bdb8e6f 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -182,7 +182,7 @@ class RoomListHandler:
                 == HistoryVisibility.WORLD_READABLE,
                 "guest_can_join": room["guest_access"] == "can_join",
                 "join_rule": room["join_rules"],
-                "org.matrix.msc3827.room_type": room["room_type"],
+                "room_type": room["room_type"],
             }
 
             # Filter out Nones – rather omit the field altogether
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 04c44b2ccb..8d01f4bf2b 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -32,6 +32,7 @@ from synapse.event_auth import get_named_level, get_power_level_event
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
+from synapse.logging import opentracing
 from synapse.module_api import NOT_SPAM
 from synapse.storage.state import StateFilter
 from synapse.types import (
@@ -94,12 +95,29 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
             burst_count=hs.config.ratelimiting.rc_joins_local.burst_count,
         )
+        # Tracks joins from local users to rooms this server isn't a member of.
+        # I.e. joins this server makes by requesting /make_join /send_join from
+        # another server.
         self._join_rate_limiter_remote = Ratelimiter(
             store=self.store,
             clock=self.clock,
             rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second,
             burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
         )
+        # TODO: find a better place to keep this Ratelimiter.
+        #   It needs to be
+        #    - written to by event persistence code
+        #    - written to by something which can snoop on replication streams
+        #    - read by the RoomMemberHandler to rate limit joins from local users
+        #    - read by the FederationServer to rate limit make_joins and send_joins from
+        #      other homeservers
+        #   I wonder if a homeserver-wide collection of rate limiters might be cleaner?
+        self._join_rate_per_room_limiter = Ratelimiter(
+            store=self.store,
+            clock=self.clock,
+            rate_hz=hs.config.ratelimiting.rc_joins_per_room.per_second,
+            burst_count=hs.config.ratelimiting.rc_joins_per_room.burst_count,
+        )
 
         # Ratelimiter for invites, keyed by room (across all issuers, all
         # recipients).
@@ -136,6 +154,18 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         )
 
         self.request_ratelimiter = hs.get_request_ratelimiter()
+        hs.get_notifier().add_new_join_in_room_callback(self._on_user_joined_room)
+
+    def _on_user_joined_room(self, event_id: str, room_id: str) -> None:
+        """Notify the rate limiter that a room join has occurred.
+
+        Use this to inform the RoomMemberHandler about joins that have either
+        - taken place on another homeserver, or
+        - on another worker in this homeserver.
+        Joins actioned by this worker should use the usual `ratelimit` method, which
+        checks the limit and increments the counter in one go.
+        """
+        self._join_rate_per_room_limiter.record_action(requester=None, key=room_id)
 
     @abc.abstractmethod
     async def _remote_join(
@@ -149,7 +179,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         """Try and join a room that this server is not in
 
         Args:
-            requester
+            requester: The user making the request, according to the access token.
             remote_room_hosts: List of servers that can be used to join via.
             room_id: Room that we are trying to join
             user: User who is trying to join
@@ -285,6 +315,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         allow_no_prev_events: bool = False,
         prev_event_ids: Optional[List[str]] = None,
         state_event_ids: Optional[List[str]] = None,
+        depth: Optional[int] = None,
         txn_id: Optional[str] = None,
         ratelimit: bool = True,
         content: Optional[dict] = None,
@@ -315,6 +346,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 prev_events are set so we need to set them ourself via this argument.
                 This should normally be left as None, which will cause the auth_event_ids
                 to be calculated based on the room state at the prev_events.
+            depth: Override the depth used to order the event in the DAG.
+                Should normally be set to None, which will cause the depth to be calculated
+                based on the prev_events.
 
             txn_id:
             ratelimit:
@@ -370,6 +404,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             allow_no_prev_events=allow_no_prev_events,
             prev_event_ids=prev_event_ids,
             state_event_ids=state_event_ids,
+            depth=depth,
             require_consent=require_consent,
             outlier=outlier,
             historical=historical,
@@ -391,14 +426,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             # up blocking profile updates.
             if newly_joined and ratelimit:
                 await self._join_rate_limiter_local.ratelimit(requester)
-
-        result_event = await self.event_creation_handler.handle_new_client_event(
-            requester,
-            event,
-            context,
-            extra_users=[target],
-            ratelimit=ratelimit,
-        )
+                await self._join_rate_per_room_limiter.ratelimit(
+                    requester, key=room_id, update=False
+                )
+        with opentracing.start_active_span("handle_new_client_event"):
+            result_event = await self.event_creation_handler.handle_new_client_event(
+                requester,
+                event,
+                context,
+                extra_users=[target],
+                ratelimit=ratelimit,
+            )
 
         if event.membership == Membership.LEAVE:
             if prev_member_event_id:
@@ -466,6 +504,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         allow_no_prev_events: bool = False,
         prev_event_ids: Optional[List[str]] = None,
         state_event_ids: Optional[List[str]] = None,
+        depth: Optional[int] = None,
     ) -> Tuple[str, int]:
         """Update a user's membership in a room.
 
@@ -501,6 +540,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 prev_events are set so we need to set them ourself via this argument.
                 This should normally be left as None, which will cause the auth_event_ids
                 to be calculated based on the room state at the prev_events.
+            depth: Override the depth used to order the event in the DAG.
+                Should normally be set to None, which will cause the depth to be calculated
+                based on the prev_events.
 
         Returns:
             A tuple of the new event ID and stream ID.
@@ -523,24 +565,26 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         # by application services), and then by room ID.
         async with self.member_as_limiter.queue(as_id):
             async with self.member_linearizer.queue(key):
-                result = await self.update_membership_locked(
-                    requester,
-                    target,
-                    room_id,
-                    action,
-                    txn_id=txn_id,
-                    remote_room_hosts=remote_room_hosts,
-                    third_party_signed=third_party_signed,
-                    ratelimit=ratelimit,
-                    content=content,
-                    new_room=new_room,
-                    require_consent=require_consent,
-                    outlier=outlier,
-                    historical=historical,
-                    allow_no_prev_events=allow_no_prev_events,
-                    prev_event_ids=prev_event_ids,
-                    state_event_ids=state_event_ids,
-                )
+                with opentracing.start_active_span("update_membership_locked"):
+                    result = await self.update_membership_locked(
+                        requester,
+                        target,
+                        room_id,
+                        action,
+                        txn_id=txn_id,
+                        remote_room_hosts=remote_room_hosts,
+                        third_party_signed=third_party_signed,
+                        ratelimit=ratelimit,
+                        content=content,
+                        new_room=new_room,
+                        require_consent=require_consent,
+                        outlier=outlier,
+                        historical=historical,
+                        allow_no_prev_events=allow_no_prev_events,
+                        prev_event_ids=prev_event_ids,
+                        state_event_ids=state_event_ids,
+                        depth=depth,
+                    )
 
         return result
 
@@ -562,6 +606,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         allow_no_prev_events: bool = False,
         prev_event_ids: Optional[List[str]] = None,
         state_event_ids: Optional[List[str]] = None,
+        depth: Optional[int] = None,
     ) -> Tuple[str, int]:
         """Helper for update_membership.
 
@@ -599,10 +644,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 prev_events are set so we need to set them ourself via this argument.
                 This should normally be left as None, which will cause the auth_event_ids
                 to be calculated based on the room state at the prev_events.
+            depth: Override the depth used to order the event in the DAG.
+                Should normally be set to None, which will cause the depth to be calculated
+                based on the prev_events.
 
         Returns:
             A tuple of the new event ID and stream ID.
         """
+
         content_specified = bool(content)
         if content is None:
             content = {}
@@ -640,7 +689,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 errcode=Codes.BAD_JSON,
             )
 
-        if "avatar_url" in content:
+        if "avatar_url" in content and content.get("avatar_url") is not None:
             if not await self.profile_handler.check_avatar_size_and_mime_type(
                 content["avatar_url"],
             ):
@@ -695,7 +744,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 is_requester_admin = True
 
             else:
-                is_requester_admin = await self.auth.is_server_admin(requester.user)
+                is_requester_admin = await self.auth.is_server_admin(requester)
 
             if not is_requester_admin:
                 if self.config.server.block_non_admin_invites:
@@ -732,6 +781,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 allow_no_prev_events=allow_no_prev_events,
                 prev_event_ids=prev_event_ids,
                 state_event_ids=state_event_ids,
+                depth=depth,
                 content=content,
                 require_consent=require_consent,
                 outlier=outlier,
@@ -740,14 +790,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         latest_event_ids = await self.store.get_prev_events_for_room(room_id)
 
-        current_state_ids = await self.state_handler.get_current_state_ids(
-            room_id, latest_event_ids=latest_event_ids
+        state_before_join = await self.state_handler.compute_state_after_events(
+            room_id, latest_event_ids
         )
 
         # TODO: Refactor into dictionary of explicitly allowed transitions
         # between old and new state, with specific error messages for some
         # transitions and generic otherwise
-        old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
+        old_state_id = state_before_join.get((EventTypes.Member, target.to_string()))
         if old_state_id:
             old_state = await self.store.get_event(old_state_id, allow_none=True)
             old_membership = old_state.content.get("membership") if old_state else None
@@ -787,7 +837,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 old_membership == Membership.INVITE
                 and effective_membership_state == Membership.LEAVE
             ):
-                is_blocked = await self._is_server_notice_room(room_id)
+                is_blocked = await self.store.is_server_notice_room(room_id)
                 if is_blocked:
                     raise SynapseError(
                         HTTPStatus.FORBIDDEN,
@@ -798,11 +848,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             if action == "kick":
                 raise AuthError(403, "The target user is not in the room")
 
-        is_host_in_room = await self._is_host_in_room(current_state_ids)
+        is_host_in_room = await self._is_host_in_room(state_before_join)
 
         if effective_membership_state == Membership.JOIN:
             if requester.is_guest:
-                guest_can_join = await self._can_guest_join(current_state_ids)
+                guest_can_join = await self._can_guest_join(state_before_join)
                 if not guest_can_join:
                     # This should be an auth check, but guests are a local concept,
                     # so don't really fit into the general auth process.
@@ -818,7 +868,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 bypass_spam_checker = True
 
             else:
-                bypass_spam_checker = await self.auth.is_server_admin(requester.user)
+                bypass_spam_checker = await self.auth.is_server_admin(requester)
 
             inviter = await self._get_inviter(target.to_string(), room_id)
             if (
@@ -840,13 +890,23 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
             # Check if a remote join should be performed.
             remote_join, remote_room_hosts = await self._should_perform_remote_join(
-                target.to_string(), room_id, remote_room_hosts, content, is_host_in_room
+                target.to_string(),
+                room_id,
+                remote_room_hosts,
+                content,
+                is_host_in_room,
+                state_before_join,
             )
             if remote_join:
                 if ratelimit:
                     await self._join_rate_limiter_remote.ratelimit(
                         requester,
                     )
+                    await self._join_rate_per_room_limiter.ratelimit(
+                        requester,
+                        key=room_id,
+                        update=False,
+                    )
 
                 inviter = await self._get_inviter(target.to_string(), room_id)
                 if inviter and not self.hs.is_mine(inviter):
@@ -967,6 +1027,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             ratelimit=ratelimit,
             prev_event_ids=latest_event_ids,
             state_event_ids=state_event_ids,
+            depth=depth,
             content=content,
             require_consent=require_consent,
             outlier=outlier,
@@ -979,6 +1040,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         remote_room_hosts: List[str],
         content: JsonDict,
         is_host_in_room: bool,
+        state_before_join: StateMap[str],
     ) -> Tuple[bool, List[str]]:
         """
         Check whether the server should do a remote join (as opposed to a local
@@ -998,6 +1060,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             content: The content to use as the event body of the join. This may
                 be modified.
             is_host_in_room: True if the host is in the room.
+            state_before_join: The state before the join event (i.e. the resolution of
+                the states after its parent events).
 
         Returns:
             A tuple of:
@@ -1014,20 +1078,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         # If the host is in the room, but not one of the authorised hosts
         # for restricted join rules, a remote join must be used.
         room_version = await self.store.get_room_version(room_id)
-        current_state_ids = await self._storage_controllers.state.get_current_state_ids(
-            room_id
-        )
 
         # If restricted join rules are not being used, a local join can always
         # be used.
         if not await self.event_auth_handler.has_restricted_join_rules(
-            current_state_ids, room_version
+            state_before_join, room_version
         ):
             return False, []
 
         # If the user is invited to the room or already joined, the join
         # event can always be issued locally.
-        prev_member_event_id = current_state_ids.get((EventTypes.Member, user_id), None)
+        prev_member_event_id = state_before_join.get((EventTypes.Member, user_id), None)
         prev_member_event = None
         if prev_member_event_id:
             prev_member_event = await self.store.get_event(prev_member_event_id)
@@ -1042,10 +1103,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         #
         # If not, generate a new list of remote hosts based on which
         # can issue invites.
-        event_map = await self.store.get_events(current_state_ids.values())
+        event_map = await self.store.get_events(state_before_join.values())
         current_state = {
             state_key: event_map[event_id]
-            for state_key, event_id in current_state_ids.items()
+            for state_key, event_id in state_before_join.items()
         }
         allowed_servers = get_servers_from_users(
             get_users_which_can_issue_invite(current_state)
@@ -1059,7 +1120,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         # Ensure the member should be allowed access via membership in a room.
         await self.event_auth_handler.check_restricted_join_rules(
-            current_state_ids, room_version, user_id, prev_member_event
+            state_before_join, room_version, user_id, prev_member_event
         )
 
         # If this is going to be a local join, additional information must
@@ -1069,7 +1130,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             EventContentFields.AUTHORISING_USER
         ] = await self.event_auth_handler.get_user_which_could_invite(
             room_id,
-            current_state_ids,
+            state_before_join,
         )
 
         return False, []
@@ -1321,8 +1382,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         id_server: str,
         requester: Requester,
         txn_id: Optional[str],
-        id_access_token: Optional[str] = None,
-    ) -> int:
+        id_access_token: str,
+        prev_event_ids: Optional[List[str]] = None,
+        depth: Optional[int] = None,
+    ) -> Tuple[str, int]:
         """Invite a 3PID to a room.
 
         Args:
@@ -1334,16 +1397,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             requester: The user making the request.
             txn_id: The transaction ID this is part of, or None if this is not
                 part of a transaction.
-            id_access_token: The optional identity server access token.
+            id_access_token: Identity server access token.
+            depth: Override the depth used to order the event in the DAG.
+            prev_event_ids: The event IDs to use as the prev events
+                Should normally be set to None, which will cause the depth to be calculated
+                based on the prev_events.
 
         Returns:
-             The new stream ID.
+            Tuple of event ID and stream ordering position
 
         Raises:
             ShadowBanError if the requester has been shadow-banned.
         """
         if self.config.server.block_non_admin_invites:
-            is_requester_admin = await self.auth.is_server_admin(requester.user)
+            is_requester_admin = await self.auth.is_server_admin(requester)
             if not is_requester_admin:
                 raise SynapseError(
                     403, "Invites have been disabled on this server", Codes.FORBIDDEN
@@ -1383,7 +1450,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             # We don't check the invite against the spamchecker(s) here (through
             # user_may_invite) because we'll do it further down the line anyway (in
             # update_membership_locked).
-            _, stream_id = await self.update_membership(
+            event_id, stream_id = await self.update_membership(
                 requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
             )
         else:
@@ -1402,7 +1469,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     additional_fields=spam_check[1],
                 )
 
-            stream_id = await self._make_and_store_3pid_invite(
+            event, stream_id = await self._make_and_store_3pid_invite(
                 requester,
                 id_server,
                 medium,
@@ -1411,9 +1478,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 inviter,
                 txn_id=txn_id,
                 id_access_token=id_access_token,
+                prev_event_ids=prev_event_ids,
+                depth=depth,
             )
+            event_id = event.event_id
 
-        return stream_id
+        return event_id, stream_id
 
     async def _make_and_store_3pid_invite(
         self,
@@ -1424,8 +1494,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         room_id: str,
         user: UserID,
         txn_id: Optional[str],
-        id_access_token: Optional[str] = None,
-    ) -> int:
+        id_access_token: str,
+        prev_event_ids: Optional[List[str]] = None,
+        depth: Optional[int] = None,
+    ) -> Tuple[EventBase, int]:
         room_state = await self._storage_controllers.state.get_current_state(
             room_id,
             StateFilter.from_types(
@@ -1518,8 +1590,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             },
             ratelimit=False,
             txn_id=txn_id,
+            prev_event_ids=prev_event_ids,
+            depth=depth,
         )
-        return stream_id
+        return event, stream_id
 
     async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool:
         # Have we just created the room, and is this about to be the very
@@ -1543,12 +1617,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         return False
 
-    async def _is_server_notice_room(self, room_id: str) -> bool:
-        if self._server_notices_mxid is None:
-            return False
-        user_ids = await self.store.get_users_in_room(room_id)
-        return self._server_notices_mxid in user_ids
-
 
 class RoomMemberMasterHandler(RoomMemberHandler):
     def __init__(self, hs: "HomeServer"):
@@ -1608,14 +1676,18 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         ]
 
         if len(remote_room_hosts) == 0:
-            raise SynapseError(404, "No known servers")
+            raise SynapseError(
+                404,
+                "Can't join remote room because no servers "
+                "that are in the room have been provided.",
+            )
 
         check_complexity = self.hs.config.server.limit_remote_rooms.enabled
         if (
             check_complexity
             and self.hs.config.server.limit_remote_rooms.admins_can_join
         ):
-            check_complexity = not await self.auth.is_server_admin(user)
+            check_complexity = not await self.store.is_server_admin(user)
 
         if check_complexity:
             # Fetch the room complexity
@@ -1845,8 +1917,11 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         ]:
             raise SynapseError(400, "User %s in room %s" % (user_id, room_id))
 
-        if membership:
-            await self.store.forget(user_id, room_id)
+        # In normal case this call is only required if `membership` is not `None`.
+        # But: After the last member had left the room, the background update
+        # `_background_remove_left_rooms` is deleting rows related to this room from
+        # the table `current_state_events` and `get_current_state_events` is `None`.
+        await self.store.forget(user_id, room_id)
 
 
 def get_users_which_can_issue_invite(auth_events: StateMap[EventBase]) -> List[str]:
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index 13098f56ed..ebd445adca 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -28,11 +28,11 @@ from synapse.api.constants import (
     RoomTypes,
 )
 from synapse.api.errors import (
-    AuthError,
     Codes,
     NotFoundError,
     StoreError,
     SynapseError,
+    UnstableSpecAuthError,
     UnsupportedRoomVersionError,
 )
 from synapse.api.ratelimiting import Ratelimiter
@@ -175,10 +175,11 @@ class RoomSummaryHandler:
 
         # First of all, check that the room is accessible.
         if not await self._is_local_room_accessible(requested_room_id, requester):
-            raise AuthError(
+            raise UnstableSpecAuthError(
                 403,
                 "User %s not in room %s, and room previews are disabled"
                 % (requester, requested_room_id),
+                errcode=Codes.NOT_JOINED,
             )
 
         # If this is continuing a previous session, pull the persisted data.
@@ -452,7 +453,6 @@ class RoomSummaryHandler:
                 "type": e.type,
                 "state_key": e.state_key,
                 "content": e.content,
-                "room_id": e.room_id,
                 "sender": e.sender,
                 "origin_server_ts": e.origin_server_ts,
             }
diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py
index a305a66860..e2844799e8 100644
--- a/synapse/handlers/send_email.py
+++ b/synapse/handlers/send_email.py
@@ -23,10 +23,12 @@ from pkg_resources import parse_version
 
 import twisted
 from twisted.internet.defer import Deferred
-from twisted.internet.interfaces import IOpenSSLContextFactory, IReactorTCP
+from twisted.internet.interfaces import IOpenSSLContextFactory
+from twisted.internet.ssl import optionsForClientTLS
 from twisted.mail.smtp import ESMTPSender, ESMTPSenderFactory
 
 from synapse.logging.context import make_deferred_yieldable
+from synapse.types import ISynapseReactor
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -48,7 +50,7 @@ class _NoTLSESMTPSender(ESMTPSender):
 
 
 async def _sendmail(
-    reactor: IReactorTCP,
+    reactor: ISynapseReactor,
     smtphost: str,
     smtpport: int,
     from_addr: str,
@@ -59,6 +61,7 @@ async def _sendmail(
     require_auth: bool = False,
     require_tls: bool = False,
     enable_tls: bool = True,
+    force_tls: bool = False,
 ) -> None:
     """A simple wrapper around ESMTPSenderFactory, to allow substitution in tests
 
@@ -73,8 +76,9 @@ async def _sendmail(
         password: password to give when authenticating
         require_auth: if auth is not offered, fail the request
         require_tls: if TLS is not offered, fail the reqest
-        enable_tls: True to enable TLS. If this is False and require_tls is True,
+        enable_tls: True to enable STARTTLS. If this is False and require_tls is True,
            the request will fail.
+        force_tls: True to enable Implicit TLS.
     """
     msg = BytesIO(msg_bytes)
     d: "Deferred[object]" = Deferred()
@@ -105,13 +109,23 @@ async def _sendmail(
         # set to enable TLS.
         factory = build_sender_factory(hostname=smtphost if enable_tls else None)
 
-    reactor.connectTCP(
-        smtphost,
-        smtpport,
-        factory,
-        timeout=30,
-        bindAddress=None,
-    )
+    if force_tls:
+        reactor.connectSSL(
+            smtphost,
+            smtpport,
+            factory,
+            optionsForClientTLS(smtphost),
+            timeout=30,
+            bindAddress=None,
+        )
+    else:
+        reactor.connectTCP(
+            smtphost,
+            smtpport,
+            factory,
+            timeout=30,
+            bindAddress=None,
+        )
 
     await make_deferred_yieldable(d)
 
@@ -132,6 +146,7 @@ class SendEmailHandler:
         self._smtp_pass = passwd.encode("utf-8") if passwd is not None else None
         self._require_transport_security = hs.config.email.require_transport_security
         self._enable_tls = hs.config.email.enable_smtp_tls
+        self._force_tls = hs.config.email.force_tls
 
         self._sendmail = _sendmail
 
@@ -189,4 +204,5 @@ class SendEmailHandler:
             require_auth=self._smtp_user is not None,
             require_tls=self._require_transport_security,
             enable_tls=self._enable_tls,
+            force_tls=self._force_tls,
         )
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index d42a414c90..5293fa4d0e 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -13,7 +13,20 @@
 # limitations under the License.
 import itertools
 import logging
-from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    AbstractSet,
+    Any,
+    Collection,
+    Dict,
+    FrozenSet,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+)
 
 import attr
 from prometheus_client import Counter
@@ -89,7 +102,7 @@ class SyncConfig:
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class TimelineBatch:
     prev_batch: StreamToken
-    events: List[EventBase]
+    events: Sequence[EventBase]
     limited: bool
     # A mapping of event ID to the bundled aggregations for the above events.
     # This is only calculated if limited is true.
@@ -507,10 +520,17 @@ class SyncHandler:
                 # ensure that we always include current state in the timeline
                 current_state_ids: FrozenSet[str] = frozenset()
                 if any(e.is_state() for e in recents):
+                    # FIXME(faster_joins): We use the partial state here as
+                    # we don't want to block `/sync` on finishing a lazy join.
+                    # Which should be fine once
+                    # https://github.com/matrix-org/synapse/issues/12989 is resolved,
+                    # since we shouldn't reach here anymore?
+                    # Note that we use the current state as a whitelist for filtering
+                    # `recents`, so partial state is only a problem when a membership
+                    # event turns up in `recents` but has not made it into the current
+                    # state.
                     current_state_ids_map = (
-                        await self._state_storage_controller.get_current_state_ids(
-                            room_id
-                        )
+                        await self.store.get_partial_current_state_ids(room_id)
                     )
                     current_state_ids = frozenset(current_state_ids_map.values())
 
@@ -579,7 +599,13 @@ class SyncHandler:
                 if any(e.is_state() for e in loaded_recents):
                     # FIXME(faster_joins): We use the partial state here as
                     # we don't want to block `/sync` on finishing a lazy join.
-                    # Is this the correct way of doing it?
+                    # Which should be fine once
+                    # https://github.com/matrix-org/synapse/issues/12989 is resolved,
+                    # since we shouldn't reach here anymore?
+                    # Note that we use the current state as a whitelist for filtering
+                    # `loaded_recents`, so partial state is only a problem when a
+                    # membership event turns up in `loaded_recents` but has not made it
+                    # into the current state.
                     current_state_ids_map = (
                         await self.store.get_partial_current_state_ids(room_id)
                     )
@@ -627,7 +653,10 @@ class SyncHandler:
         )
 
     async def get_state_after_event(
-        self, event_id: str, state_filter: Optional[StateFilter] = None
+        self,
+        event_id: str,
+        state_filter: Optional[StateFilter] = None,
+        await_full_state: bool = True,
     ) -> StateMap[str]:
         """
         Get the room state after the given event
@@ -635,9 +664,14 @@ class SyncHandler:
         Args:
             event_id: event of interest
             state_filter: The state filter used to fetch state from the database.
+            await_full_state: if `True`, will block if we do not yet have complete state
+                at the event and `state_filter` is not satisfied by partial state.
+                Defaults to `True`.
         """
         state_ids = await self._state_storage_controller.get_state_ids_for_event(
-            event_id, state_filter=state_filter or StateFilter.all()
+            event_id,
+            state_filter=state_filter or StateFilter.all(),
+            await_full_state=await_full_state,
         )
 
         # using get_metadata_for_events here (instead of get_event) sidesteps an issue
@@ -660,6 +694,7 @@ class SyncHandler:
         room_id: str,
         stream_position: StreamToken,
         state_filter: Optional[StateFilter] = None,
+        await_full_state: bool = True,
     ) -> StateMap[str]:
         """Get the room state at a particular stream position
 
@@ -667,6 +702,9 @@ class SyncHandler:
             room_id: room for which to get state
             stream_position: point at which to get state
             state_filter: The state filter used to fetch state from the database.
+            await_full_state: if `True`, will block if we do not yet have complete state
+                at the last event in the room before `stream_position` and
+                `state_filter` is not satisfied by partial state. Defaults to `True`.
         """
         # FIXME: This gets the state at the latest event before the stream ordering,
         # which might not be the same as the "current state" of the room at the time
@@ -678,7 +716,9 @@ class SyncHandler:
 
         if last_event_id:
             state = await self.get_state_after_event(
-                last_event_id, state_filter=state_filter or StateFilter.all()
+                last_event_id,
+                state_filter=state_filter or StateFilter.all(),
+                await_full_state=await_full_state,
             )
 
         else:
@@ -852,16 +892,26 @@ class SyncHandler:
         now_token: StreamToken,
         full_state: bool,
     ) -> MutableStateMap[EventBase]:
-        """Works out the difference in state between the start of the timeline
-        and the previous sync.
+        """Works out the difference in state between the end of the previous sync and
+        the start of the timeline.
 
         Args:
             room_id:
             batch: The timeline batch for the room that will be sent to the user.
             sync_config:
-            since_token: Token of the end of the previous batch. May be None.
+            since_token: Token of the end of the previous batch. May be `None`.
             now_token: Token of the end of the current batch.
             full_state: Whether to force returning the full state.
+                `lazy_load_members` still applies when `full_state` is `True`.
+
+        Returns:
+            The state to return in the sync response for the room.
+
+            Clients will overlay this onto the state at the end of the previous sync to
+            arrive at the state at the start of the timeline.
+
+            Clients will then overlay state events in the timeline to arrive at the
+            state at the end of the timeline, in preparation for the next sync.
         """
         # TODO(mjark) Check if the state events were received by the server
         # after the previous sync, since we need to include those state
@@ -869,8 +919,17 @@ class SyncHandler:
         # TODO(mjark) Check for new redactions in the state events.
 
         with Measure(self.clock, "compute_state_delta"):
+            # The memberships needed for events in the timeline.
+            # Only calculated when `lazy_load_members` is on.
+            members_to_fetch: Optional[Set[str]] = None
+
+            # A dictionary mapping user IDs to the first event in the timeline sent by
+            # them. Only calculated when `lazy_load_members` is on.
+            first_event_by_sender_map: Optional[Dict[str, EventBase]] = None
 
-            members_to_fetch = None
+            # The contribution to the room state from state events in the timeline.
+            # Only contains the last event for any given state key.
+            timeline_state: StateMap[str]
 
             lazy_load_members = sync_config.filter_collection.lazy_load_members()
             include_redundant_members = (
@@ -881,10 +940,23 @@ class SyncHandler:
                 # We only request state for the members needed to display the
                 # timeline:
 
-                members_to_fetch = {
-                    event.sender  # FIXME: we also care about invite targets etc.
-                    for event in batch.events
-                }
+                timeline_state = {}
+
+                members_to_fetch = set()
+                first_event_by_sender_map = {}
+                for event in batch.events:
+                    # Build the map from user IDs to the first timeline event they sent.
+                    if event.sender not in first_event_by_sender_map:
+                        first_event_by_sender_map[event.sender] = event
+
+                    # We need the event's sender, unless their membership was in a
+                    # previous timeline event.
+                    if (EventTypes.Member, event.sender) not in timeline_state:
+                        members_to_fetch.add(event.sender)
+                    # FIXME: we also care about invite targets etc.
+
+                    if event.is_state():
+                        timeline_state[(event.type, event.state_key)] = event.event_id
 
                 if full_state:
                     # always make sure we LL ourselves so we know we're in the room
@@ -894,55 +966,80 @@ class SyncHandler:
                     members_to_fetch.add(sync_config.user.to_string())
 
                 state_filter = StateFilter.from_lazy_load_member_list(members_to_fetch)
+
+                # We are happy to use partial state to compute the `/sync` response.
+                # Since partial state may not include the lazy-loaded memberships we
+                # require, we fix up the state response afterwards with memberships from
+                # auth events.
+                await_full_state = False
             else:
+                timeline_state = {
+                    (event.type, event.state_key): event.event_id
+                    for event in batch.events
+                    if event.is_state()
+                }
+
                 state_filter = StateFilter.all()
+                await_full_state = True
 
-            timeline_state = {
-                (event.type, event.state_key): event.event_id
-                for event in batch.events
-                if event.is_state()
-            }
+            # Now calculate the state to return in the sync response for the room.
+            # This is more or less the change in state between the end of the previous
+            # sync's timeline and the start of the current sync's timeline.
+            # See the docstring above for details.
+            state_ids: StateMap[str]
 
             if full_state:
                 if batch:
-                    current_state_ids = (
+                    state_at_timeline_end = (
                         await self._state_storage_controller.get_state_ids_for_event(
-                            batch.events[-1].event_id, state_filter=state_filter
+                            batch.events[-1].event_id,
+                            state_filter=state_filter,
+                            await_full_state=await_full_state,
                         )
                     )
 
-                    state_ids = (
+                    state_at_timeline_start = (
                         await self._state_storage_controller.get_state_ids_for_event(
-                            batch.events[0].event_id, state_filter=state_filter
+                            batch.events[0].event_id,
+                            state_filter=state_filter,
+                            await_full_state=await_full_state,
                         )
                     )
 
                 else:
-                    current_state_ids = await self.get_state_at(
-                        room_id, stream_position=now_token, state_filter=state_filter
+                    state_at_timeline_end = await self.get_state_at(
+                        room_id,
+                        stream_position=now_token,
+                        state_filter=state_filter,
+                        await_full_state=await_full_state,
                     )
 
-                    state_ids = current_state_ids
+                    state_at_timeline_start = state_at_timeline_end
 
                 state_ids = _calculate_state(
                     timeline_contains=timeline_state,
-                    timeline_start=state_ids,
-                    previous={},
-                    current=current_state_ids,
+                    timeline_start=state_at_timeline_start,
+                    timeline_end=state_at_timeline_end,
+                    previous_timeline_end={},
                     lazy_load_members=lazy_load_members,
                 )
             elif batch.limited:
                 if batch:
                     state_at_timeline_start = (
                         await self._state_storage_controller.get_state_ids_for_event(
-                            batch.events[0].event_id, state_filter=state_filter
+                            batch.events[0].event_id,
+                            state_filter=state_filter,
+                            await_full_state=await_full_state,
                         )
                     )
                 else:
                     # We can get here if the user has ignored the senders of all
                     # the recent events.
                     state_at_timeline_start = await self.get_state_at(
-                        room_id, stream_position=now_token, state_filter=state_filter
+                        room_id,
+                        stream_position=now_token,
+                        state_filter=state_filter,
+                        await_full_state=await_full_state,
                     )
 
                 # for now, we disable LL for gappy syncs - see
@@ -964,28 +1061,35 @@ class SyncHandler:
                 # is indeed the case.
                 assert since_token is not None
                 state_at_previous_sync = await self.get_state_at(
-                    room_id, stream_position=since_token, state_filter=state_filter
+                    room_id,
+                    stream_position=since_token,
+                    state_filter=state_filter,
+                    await_full_state=await_full_state,
                 )
 
                 if batch:
-                    current_state_ids = (
+                    state_at_timeline_end = (
                         await self._state_storage_controller.get_state_ids_for_event(
-                            batch.events[-1].event_id, state_filter=state_filter
+                            batch.events[-1].event_id,
+                            state_filter=state_filter,
+                            await_full_state=await_full_state,
                         )
                     )
                 else:
-                    # Its not clear how we get here, but empirically we do
-                    # (#5407). Logging has been added elsewhere to try and
-                    # figure out where this state comes from.
-                    current_state_ids = await self.get_state_at(
-                        room_id, stream_position=now_token, state_filter=state_filter
+                    # We can get here if the user has ignored the senders of all
+                    # the recent events.
+                    state_at_timeline_end = await self.get_state_at(
+                        room_id,
+                        stream_position=now_token,
+                        state_filter=state_filter,
+                        await_full_state=await_full_state,
                     )
 
                 state_ids = _calculate_state(
                     timeline_contains=timeline_state,
                     timeline_start=state_at_timeline_start,
-                    previous=state_at_previous_sync,
-                    current=current_state_ids,
+                    timeline_end=state_at_timeline_end,
+                    previous_timeline_end=state_at_previous_sync,
                     # we have to include LL members in case LL initial sync missed them
                     lazy_load_members=lazy_load_members,
                 )
@@ -1008,8 +1112,30 @@ class SyncHandler:
                                 (EventTypes.Member, member)
                                 for member in members_to_fetch
                             ),
+                            await_full_state=False,
                         )
 
+            # If we only have partial state for the room, `state_ids` may be missing the
+            # memberships we wanted. We attempt to find some by digging through the auth
+            # events of timeline events.
+            if lazy_load_members and await self.store.is_partial_state_room(room_id):
+                assert members_to_fetch is not None
+                assert first_event_by_sender_map is not None
+
+                additional_state_ids = (
+                    await self._find_missing_partial_state_memberships(
+                        room_id, members_to_fetch, first_event_by_sender_map, state_ids
+                    )
+                )
+                state_ids = {**state_ids, **additional_state_ids}
+
+            # At this point, if `lazy_load_members` is enabled, `state_ids` includes
+            # the memberships of all event senders in the timeline. This is because we
+            # may not have sent the memberships in a previous sync.
+
+            # When `include_redundant_members` is on, we send all the lazy-loaded
+            # memberships of event senders. Otherwise we make an effort to limit the set
+            # of memberships we send to those that we have not already sent to this client.
             if lazy_load_members and not include_redundant_members:
                 cache_key = (sync_config.user.to_string(), sync_config.device_id)
                 cache = self.get_lazy_loaded_members_cache(cache_key)
@@ -1051,6 +1177,99 @@ class SyncHandler:
             if e.type != EventTypes.Aliases  # until MSC2261 or alternative solution
         }
 
+    async def _find_missing_partial_state_memberships(
+        self,
+        room_id: str,
+        members_to_fetch: Collection[str],
+        events_with_membership_auth: Mapping[str, EventBase],
+        found_state_ids: StateMap[str],
+    ) -> StateMap[str]:
+        """Finds missing memberships from a set of auth events and returns them as a
+        state map.
+
+        Args:
+            room_id: The partial state room to find the remaining memberships for.
+            members_to_fetch: The memberships to find.
+            events_with_membership_auth: A mapping from user IDs to events whose auth
+                events are known to contain their membership.
+            found_state_ids: A dict from (type, state_key) -> state_event_id, containing
+                memberships that have been previously found. Entries in
+                `members_to_fetch` that have a membership in `found_state_ids` are
+                ignored.
+
+        Returns:
+            A dict from ("m.room.member", state_key) -> state_event_id, containing the
+            memberships missing from `found_state_ids`.
+
+        Raises:
+            KeyError: if `events_with_membership_auth` does not have an entry for a
+                missing membership. Memberships in `found_state_ids` do not need an
+                entry in `events_with_membership_auth`.
+        """
+        additional_state_ids: MutableStateMap[str] = {}
+
+        # Tracks the missing members for logging purposes.
+        missing_members = set()
+
+        # Identify memberships missing from `found_state_ids` and pick out the auth
+        # events in which to look for them.
+        auth_event_ids: Set[str] = set()
+        for member in members_to_fetch:
+            if (EventTypes.Member, member) in found_state_ids:
+                continue
+
+            missing_members.add(member)
+            event_with_membership_auth = events_with_membership_auth[member]
+            auth_event_ids.update(event_with_membership_auth.auth_event_ids())
+
+        auth_events = await self.store.get_events(auth_event_ids)
+
+        # Run through the missing memberships once more, picking out the memberships
+        # from the pile of auth events we have just fetched.
+        for member in members_to_fetch:
+            if (EventTypes.Member, member) in found_state_ids:
+                continue
+
+            event_with_membership_auth = events_with_membership_auth[member]
+
+            # Dig through the auth events to find the desired membership.
+            for auth_event_id in event_with_membership_auth.auth_event_ids():
+                # We only store events once we have all their auth events,
+                # so the auth event must be in the pile we have just
+                # fetched.
+                auth_event = auth_events[auth_event_id]
+
+                if (
+                    auth_event.type == EventTypes.Member
+                    and auth_event.state_key == member
+                ):
+                    missing_members.remove(member)
+                    additional_state_ids[
+                        (EventTypes.Member, member)
+                    ] = auth_event.event_id
+                    break
+
+        if missing_members:
+            # There really shouldn't be any missing memberships now. Either:
+            #  * we couldn't find an auth event, which shouldn't happen because we do
+            #    not persist events with persisting their auth events first, or
+            #  * the set of auth events did not contain a membership we wanted, which
+            #    means our caller didn't compute the events in `members_to_fetch`
+            #    correctly, or we somehow accepted an event whose auth events were
+            #    dodgy.
+            logger.error(
+                "Failed to find memberships for %s in partial state room "
+                "%s in the auth events of %s.",
+                missing_members,
+                room_id,
+                [
+                    events_with_membership_auth[member].event_id
+                    for member in missing_members
+                ],
+            )
+
+        return additional_state_ids
+
     async def unread_notifs_for_room_id(
         self, room_id: str, sync_config: SyncConfig
     ) -> NotifCounts:
@@ -1195,10 +1414,10 @@ class SyncHandler:
     async def _generate_sync_entry_for_device_list(
         self,
         sync_result_builder: "SyncResultBuilder",
-        newly_joined_rooms: Set[str],
-        newly_joined_or_invited_or_knocked_users: Set[str],
-        newly_left_rooms: Set[str],
-        newly_left_users: Set[str],
+        newly_joined_rooms: AbstractSet[str],
+        newly_joined_or_invited_or_knocked_users: AbstractSet[str],
+        newly_left_rooms: AbstractSet[str],
+        newly_left_users: AbstractSet[str],
     ) -> DeviceListUpdates:
         """Generate the DeviceListUpdates section of sync
 
@@ -1216,8 +1435,7 @@ class SyncHandler:
         user_id = sync_result_builder.sync_config.user.to_string()
         since_token = sync_result_builder.since_token
 
-        # We're going to mutate these fields, so lets copy them rather than
-        # assume they won't get used later.
+        # Take a copy since these fields will be mutated later.
         newly_joined_or_invited_or_knocked_users = set(
             newly_joined_or_invited_or_knocked_users
         )
@@ -1417,8 +1635,8 @@ class SyncHandler:
     async def _generate_sync_entry_for_presence(
         self,
         sync_result_builder: "SyncResultBuilder",
-        newly_joined_rooms: Set[str],
-        newly_joined_or_invited_users: Set[str],
+        newly_joined_rooms: AbstractSet[str],
+        newly_joined_or_invited_users: AbstractSet[str],
     ) -> None:
         """Generates the presence portion of the sync response. Populates the
         `sync_result_builder` with the result.
@@ -1476,7 +1694,7 @@ class SyncHandler:
         self,
         sync_result_builder: "SyncResultBuilder",
         account_data_by_room: Dict[str, Dict[str, JsonDict]],
-    ) -> Tuple[Set[str], Set[str], Set[str], Set[str]]:
+    ) -> Tuple[AbstractSet[str], AbstractSet[str], AbstractSet[str], AbstractSet[str]]:
         """Generates the rooms portion of the sync response. Populates the
         `sync_result_builder` with the result.
 
@@ -1536,15 +1754,13 @@ class SyncHandler:
         ignored_users = await self.store.ignored_users(user_id)
         if since_token:
             room_changes = await self._get_rooms_changed(
-                sync_result_builder, ignored_users, self.rooms_to_exclude
+                sync_result_builder, ignored_users
             )
             tags_by_room = await self.store.get_updated_tags(
                 user_id, since_token.account_data_key
             )
         else:
-            room_changes = await self._get_all_rooms(
-                sync_result_builder, ignored_users, self.rooms_to_exclude
-            )
+            room_changes = await self._get_all_rooms(sync_result_builder, ignored_users)
             tags_by_room = await self.store.get_tags_for_user(user_id)
 
         log_kv({"rooms_changed": len(room_changes.room_entries)})
@@ -1623,13 +1839,14 @@ class SyncHandler:
         self,
         sync_result_builder: "SyncResultBuilder",
         ignored_users: FrozenSet[str],
-        excluded_rooms: List[str],
     ) -> _RoomChanges:
         """Determine the changes in rooms to report to the user.
 
         This function is a first pass at generating the rooms part of the sync response.
         It determines which rooms have changed during the sync period, and categorises
-        them into four buckets: "knock", "invite", "join" and "leave".
+        them into four buckets: "knock", "invite", "join" and "leave". It also excludes
+        from that list any room that appears in the list of rooms to exclude from sync
+        results in the server configuration.
 
         1. Finds all membership changes for the user in the sync period (from
            `since_token` up to `now_token`).
@@ -1655,7 +1872,7 @@ class SyncHandler:
         #       _have_rooms_changed. We could keep the results in memory to avoid a
         #       second query, at the cost of more complicated source code.
         membership_change_events = await self.store.get_membership_changes_for_user(
-            user_id, since_token.room_key, now_token.room_key, excluded_rooms
+            user_id, since_token.room_key, now_token.room_key, self.rooms_to_exclude
         )
 
         mem_change_events_by_room_id: Dict[str, List[EventBase]] = {}
@@ -1696,7 +1913,11 @@ class SyncHandler:
                 continue
 
             if room_id in sync_result_builder.joined_room_ids or has_join:
-                old_state_ids = await self.get_state_at(room_id, since_token)
+                old_state_ids = await self.get_state_at(
+                    room_id,
+                    since_token,
+                    state_filter=StateFilter.from_types([(EventTypes.Member, user_id)]),
+                )
                 old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
                 old_mem_ev = None
                 if old_mem_ev_id:
@@ -1722,7 +1943,13 @@ class SyncHandler:
                     newly_left_rooms.append(room_id)
                 else:
                     if not old_state_ids:
-                        old_state_ids = await self.get_state_at(room_id, since_token)
+                        old_state_ids = await self.get_state_at(
+                            room_id,
+                            since_token,
+                            state_filter=StateFilter.from_types(
+                                [(EventTypes.Member, user_id)]
+                            ),
+                        )
                         old_mem_ev_id = old_state_ids.get(
                             (EventTypes.Member, user_id), None
                         )
@@ -1862,7 +2089,6 @@ class SyncHandler:
         self,
         sync_result_builder: "SyncResultBuilder",
         ignored_users: FrozenSet[str],
-        ignored_rooms: List[str],
     ) -> _RoomChanges:
         """Returns entries for all rooms for the user.
 
@@ -1884,7 +2110,7 @@ class SyncHandler:
         room_list = await self.store.get_rooms_for_local_user_where_membership_is(
             user_id=user_id,
             membership_list=Membership.LIST,
-            excluded_rooms=ignored_rooms,
+            excluded_rooms=self.rooms_to_exclude,
         )
 
         room_entries = []
@@ -2150,7 +2376,9 @@ class SyncHandler:
                 raise Exception("Unrecognized rtype: %r", room_builder.rtype)
 
     async def get_rooms_for_user_at(
-        self, user_id: str, room_key: RoomStreamToken
+        self,
+        user_id: str,
+        room_key: RoomStreamToken,
     ) -> FrozenSet[str]:
         """Get set of joined rooms for a user at the given stream ordering.
 
@@ -2176,7 +2404,12 @@ class SyncHandler:
         # If the membership's stream ordering is after the given stream
         # ordering, we need to go and work out if the user was in the room
         # before.
+        # We also need to check whether the room should be excluded from sync
+        # responses as per the homeserver config.
         for joined_room in joined_rooms:
+            if joined_room.room_id in self.rooms_to_exclude:
+                continue
+
             if not joined_room.event_pos.persisted_after(room_key):
                 joined_room_ids.add(joined_room.room_id)
                 continue
@@ -2188,10 +2421,10 @@ class SyncHandler:
                     joined_room.room_id, joined_room.event_pos.stream
                 )
             )
-            users_in_room = await self.state.get_current_users_in_room(
+            user_ids_in_room = await self.state.get_current_user_ids_in_room(
                 joined_room.room_id, extrems
             )
-            if user_id in users_in_room:
+            if user_id in user_ids_in_room:
                 joined_room_ids.add(joined_room.room_id)
 
         return frozenset(joined_room_ids)
@@ -2211,8 +2444,8 @@ def _action_has_highlight(actions: List[JsonDict]) -> bool:
 def _calculate_state(
     timeline_contains: StateMap[str],
     timeline_start: StateMap[str],
-    previous: StateMap[str],
-    current: StateMap[str],
+    timeline_end: StateMap[str],
+    previous_timeline_end: StateMap[str],
     lazy_load_members: bool,
 ) -> StateMap[str]:
     """Works out what state to include in a sync response.
@@ -2220,45 +2453,50 @@ def _calculate_state(
     Args:
         timeline_contains: state in the timeline
         timeline_start: state at the start of the timeline
-        previous: state at the end of the previous sync (or empty dict
+        timeline_end: state at the end of the timeline
+        previous_timeline_end: state at the end of the previous sync (or empty dict
             if this is an initial sync)
-        current: state at the end of the timeline
         lazy_load_members: whether to return members from timeline_start
             or not.  assumes that timeline_start has already been filtered to
             include only the members the client needs to know about.
     """
-    event_id_to_key = {
-        e: key
-        for key, e in itertools.chain(
+    event_id_to_state_key = {
+        event_id: state_key
+        for state_key, event_id in itertools.chain(
             timeline_contains.items(),
-            previous.items(),
             timeline_start.items(),
-            current.items(),
+            timeline_end.items(),
+            previous_timeline_end.items(),
         )
     }
 
-    c_ids = set(current.values())
-    ts_ids = set(timeline_start.values())
-    p_ids = set(previous.values())
-    tc_ids = set(timeline_contains.values())
+    timeline_end_ids = set(timeline_end.values())
+    timeline_start_ids = set(timeline_start.values())
+    previous_timeline_end_ids = set(previous_timeline_end.values())
+    timeline_contains_ids = set(timeline_contains.values())
 
     # If we are lazyloading room members, we explicitly add the membership events
     # for the senders in the timeline into the state block returned by /sync,
     # as we may not have sent them to the client before.  We find these membership
     # events by filtering them out of timeline_start, which has already been filtered
     # to only include membership events for the senders in the timeline.
-    # In practice, we can do this by removing them from the p_ids list,
-    # which is the list of relevant state we know we have already sent to the client.
+    # In practice, we can do this by removing them from the previous_timeline_end_ids
+    # list, which is the list of relevant state we know we have already sent to the
+    # client.
     # see https://github.com/matrix-org/synapse/pull/2970/files/efcdacad7d1b7f52f879179701c7e0d9b763511f#r204732809
 
     if lazy_load_members:
-        p_ids.difference_update(
+        previous_timeline_end_ids.difference_update(
             e for t, e in timeline_start.items() if t[0] == EventTypes.Member
         )
 
-    state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids
+    state_ids = (
+        (timeline_end_ids | timeline_start_ids)
+        - previous_timeline_end_ids
+        - timeline_contains_ids
+    )
 
-    return {event_id_to_key[e]: e for e in state_ids}
+    return {event_id_to_state_key[e]: e for e in state_ids}
 
 
 @attr.s(slots=True, auto_attribs=True)
@@ -2296,7 +2534,7 @@ class SyncResultBuilder:
     archived: List[ArchivedSyncResult] = attr.Factory(list)
     to_device: List[JsonDict] = attr.Factory(list)
 
-    def calculate_user_changes(self) -> Tuple[Set[str], Set[str]]:
+    def calculate_user_changes(self) -> Tuple[AbstractSet[str], AbstractSet[str]]:
         """Work out which other users have joined or left rooms we are joined to.
 
         This data only is only useful for an incremental sync.
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index d104ea07fe..a4cd8b8f0c 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -26,7 +26,7 @@ from synapse.metrics.background_process_metrics import (
 )
 from synapse.replication.tcp.streams import TypingStream
 from synapse.streams import EventSource
-from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id
+from synapse.types import JsonDict, Requester, StreamKeyType, UserID
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.metrics import Measure
 from synapse.util.wheel_timer import WheelTimer
@@ -253,12 +253,11 @@ class TypingWriterHandler(FollowerTypingHandler):
         self, target_user: UserID, requester: Requester, room_id: str, timeout: int
     ) -> None:
         target_user_id = target_user.to_string()
-        auth_user_id = requester.user.to_string()
 
         if not self.is_mine_id(target_user_id):
             raise SynapseError(400, "User is not hosted on this homeserver")
 
-        if target_user_id != auth_user_id:
+        if target_user != requester.user:
             raise AuthError(400, "Cannot set another user's typing state")
 
         if requester.shadow_banned:
@@ -266,7 +265,7 @@ class TypingWriterHandler(FollowerTypingHandler):
             await self.clock.sleep(random.randint(1, 10))
             raise ShadowBanError()
 
-        await self.auth.check_user_in_room(room_id, target_user_id)
+        await self.auth.check_user_in_room(room_id, requester)
 
         logger.debug("%s has started typing in %s", target_user_id, room_id)
 
@@ -289,12 +288,11 @@ class TypingWriterHandler(FollowerTypingHandler):
         self, target_user: UserID, requester: Requester, room_id: str
     ) -> None:
         target_user_id = target_user.to_string()
-        auth_user_id = requester.user.to_string()
 
         if not self.is_mine_id(target_user_id):
             raise SynapseError(400, "User is not hosted on this homeserver")
 
-        if target_user_id != auth_user_id:
+        if target_user != requester.user:
             raise AuthError(400, "Cannot set another user's typing state")
 
         if requester.shadow_banned:
@@ -302,7 +300,7 @@ class TypingWriterHandler(FollowerTypingHandler):
             await self.clock.sleep(random.randint(1, 10))
             raise ShadowBanError()
 
-        await self.auth.check_user_in_room(room_id, target_user_id)
+        await self.auth.check_user_in_room(room_id, requester)
 
         logger.debug("%s has stopped typing in %s", target_user_id, room_id)
 
@@ -364,8 +362,9 @@ class TypingWriterHandler(FollowerTypingHandler):
             )
             return
 
-        users = await self.store.get_users_in_room(room_id)
-        domains = {get_domain_from_id(u) for u in users}
+        domains = await self._storage_controllers.state.get_current_hosts_in_room(
+            room_id
+        )
 
         if self.server_name in domains:
             logger.info("Got typing update from %s: %r", user_id, content)
@@ -489,8 +488,15 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
             handler = self.get_typing_handler()
 
             events = []
-            for room_id in handler._room_serials.keys():
-                if handler._room_serials[room_id] <= from_key:
+
+            # Work on a copy of things here as these may change in the handler while
+            # waiting for the AS `is_interested_in_room` call to complete.
+            # Shallow copy is safe as no nested data is present.
+            latest_room_serial = handler._latest_room_serial
+            room_serials = handler._room_serials.copy()
+
+            for room_id, serial in room_serials.items():
+                if serial <= from_key:
                     continue
 
                 if not await service.is_interested_in_room(room_id, self._main_store):
@@ -498,7 +504,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
 
                 events.append(self._make_event_for(room_id))
 
-            return events, handler._latest_room_serial
+            return events, latest_room_serial
 
     async def get_new_events(
         self,
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 05cebb5d4d..a744d68c64 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -19,7 +19,6 @@ from twisted.web.client import PartialDownloadError
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import Codes, LoginError, SynapseError
-from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.util import json_decoder
 
 if TYPE_CHECKING:
@@ -153,7 +152,7 @@ class _BaseThreepidAuthChecker:
 
         logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))
 
-        # msisdns are currently always ThreepidBehaviour.REMOTE
+        # msisdns are currently always verified via the IS
         if medium == "msisdn":
             if not self.hs.config.registration.account_threepid_delegate_msisdn:
                 raise SynapseError(
@@ -164,18 +163,7 @@ class _BaseThreepidAuthChecker:
                 threepid_creds,
             )
         elif medium == "email":
-            if (
-                self.hs.config.email.threepid_behaviour_email
-                == ThreepidBehaviour.REMOTE
-            ):
-                assert self.hs.config.registration.account_threepid_delegate_email
-                threepid = await identity_handler.threepid_from_creds(
-                    self.hs.config.registration.account_threepid_delegate_email,
-                    threepid_creds,
-                )
-            elif (
-                self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL
-            ):
+            if self.hs.config.email.can_verify_email:
                 threepid = None
                 row = await self.store.get_threepid_validation_session(
                     medium,
@@ -227,10 +215,7 @@ class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChec
         _BaseThreepidAuthChecker.__init__(self, hs)
 
     def is_enabled(self) -> bool:
-        return self.hs.config.email.threepid_behaviour_email in (
-            ThreepidBehaviour.REMOTE,
-            ThreepidBehaviour.LOCAL,
-        )
+        return self.hs.config.email.can_verify_email
 
     async def check_auth(self, authdict: dict, clientip: str) -> Any:
         return await self._check_threepid("email", authdict)
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index c63d068f74..3c35b1d2c7 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -79,6 +79,7 @@ from synapse.types import JsonDict
 from synapse.util import json_decoder
 from synapse.util.async_helpers import AwakenableSleeper, timeout_deferred
 from synapse.util.metrics import Measure
+from synapse.util.stringutils import parse_and_validate_server_name
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -479,6 +480,14 @@ class MatrixFederationHttpClient:
             RequestSendFailed: If there were problems connecting to the
                 remote, due to e.g. DNS failures, connection timeouts etc.
         """
+        # Validate server name and log if it is an invalid destination, this is
+        # partially to help track down code paths where we haven't validated before here
+        try:
+            parse_and_validate_server_name(request.destination)
+        except ValueError:
+            logger.exception(f"Invalid destination: {request.destination}.")
+            raise FederationDeniedError(request.destination)
+
         if timeout:
             _sec_timeout = timeout / 1000
         else:
diff --git a/synapse/http/server.py b/synapse/http/server.py
index cf2d6f904b..6068a94b40 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -33,7 +33,6 @@ from typing import (
     Optional,
     Pattern,
     Tuple,
-    TypeVar,
     Union,
 )
 
@@ -58,11 +57,13 @@ from synapse.api.errors import (
     SynapseError,
     UnrecognizedRequestError,
 )
+from synapse.config.homeserver import HomeServerConfig
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import defer_to_thread, preserve_fn, run_in_background
 from synapse.logging.opentracing import active_span, start_active_span, trace_servlet
 from synapse.util import json_encoder
 from synapse.util.caches import intern_dict
+from synapse.util.cancellation import is_function_cancellable
 from synapse.util.iterutils import chunk_seq
 
 if TYPE_CHECKING:
@@ -93,77 +94,16 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
 HTTP_STATUS_REQUEST_CANCELLED = 499
 
 
-F = TypeVar("F", bound=Callable[..., Any])
-
-
-_cancellable_method_names = frozenset(
-    {
-        # `RestServlet`, `BaseFederationServlet` and `BaseFederationServerServlet`
-        # methods
-        "on_GET",
-        "on_PUT",
-        "on_POST",
-        "on_DELETE",
-        # `_AsyncResource`, `DirectServeHtmlResource` and `DirectServeJsonResource`
-        # methods
-        "_async_render_GET",
-        "_async_render_PUT",
-        "_async_render_POST",
-        "_async_render_DELETE",
-        "_async_render_OPTIONS",
-        # `ReplicationEndpoint` methods
-        "_handle_request",
-    }
-)
-
-
-def cancellable(method: F) -> F:
-    """Marks a servlet method as cancellable.
-
-    Methods with this decorator will be cancelled if the client disconnects before we
-    finish processing the request.
-
-    During cancellation, `Deferred.cancel()` will be invoked on the `Deferred` wrapping
-    the method. The `cancel()` call will propagate down to the `Deferred` that is
-    currently being waited on. That `Deferred` will raise a `CancelledError`, which will
-    propagate up, as per normal exception handling.
-
-    Before applying this decorator to a new endpoint, you MUST recursively check
-    that all `await`s in the function are on `async` functions or `Deferred`s that
-    handle cancellation cleanly, otherwise a variety of bugs may occur, ranging from
-    premature logging context closure, to stuck requests, to database corruption.
-
-    Usage:
-        class SomeServlet(RestServlet):
-            @cancellable
-            async def on_GET(self, request: SynapseRequest) -> ...:
-                ...
-    """
-    if method.__name__ not in _cancellable_method_names and not any(
-        method.__name__.startswith(prefix) for prefix in _cancellable_method_names
-    ):
-        raise ValueError(
-            "@cancellable decorator can only be applied to servlet methods."
-        )
-
-    method.cancellable = True  # type: ignore[attr-defined]
-    return method
-
-
-def is_method_cancellable(method: Callable[..., Any]) -> bool:
-    """Checks whether a servlet method has the `@cancellable` flag."""
-    return getattr(method, "cancellable", False)
-
-
-def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
+def return_json_error(
+    f: failure.Failure, request: SynapseRequest, config: Optional[HomeServerConfig]
+) -> None:
     """Sends a JSON error response to clients."""
 
     if f.check(SynapseError):
         # mypy doesn't understand that f.check asserts the type.
         exc: SynapseError = f.value  # type: ignore
         error_code = exc.code
-        error_dict = exc.error_dict()
-
+        error_dict = exc.error_dict(config)
         logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg)
     elif f.check(CancelledError):
         error_code = HTTP_STATUS_REQUEST_CANCELLED
@@ -387,7 +327,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
 
         method_handler = getattr(self, "_async_render_%s" % (request_method,), None)
         if method_handler:
-            request.is_render_cancellable = is_method_cancellable(method_handler)
+            request.is_render_cancellable = is_function_cancellable(method_handler)
 
             raw_callback_return = method_handler(request)
 
@@ -450,7 +390,7 @@ class DirectServeJsonResource(_AsyncResource):
         request: SynapseRequest,
     ) -> None:
         """Implements _AsyncResource._send_error_response"""
-        return_json_error(f, request)
+        return_json_error(f, request, None)
 
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -549,7 +489,7 @@ class JsonResource(DirectServeJsonResource):
     async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]:
         callback, servlet_classname, group_dict = self._get_handler_for_request(request)
 
-        request.is_render_cancellable = is_method_cancellable(callback)
+        request.is_render_cancellable = is_function_cancellable(callback)
 
         # Make sure we have an appropriate name for this handler in prometheus
         # (rather than the default of JsonResource).
@@ -575,6 +515,14 @@ class JsonResource(DirectServeJsonResource):
 
         return callback_return
 
+    def _send_error_response(
+        self,
+        f: failure.Failure,
+        request: SynapseRequest,
+    ) -> None:
+        """Implements _AsyncResource._send_error_response"""
+        return_json_error(f, request, self.hs.config)
+
 
 class DirectServeHtmlResource(_AsyncResource):
     """A resource that will call `self._async_on_<METHOD>` on new requests,
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 4ff840ca0e..80acbdcf3c 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -23,9 +23,13 @@ from typing import (
     Optional,
     Sequence,
     Tuple,
+    Type,
+    TypeVar,
     overload,
 )
 
+from pydantic import BaseModel, MissingError, PydanticValueError, ValidationError
+from pydantic.error_wrappers import ErrorWrapper
 from typing_extensions import Literal
 
 from twisted.web.server import Request
@@ -694,6 +698,42 @@ def parse_json_object_from_request(
     return content
 
 
+Model = TypeVar("Model", bound=BaseModel)
+
+
+def parse_and_validate_json_object_from_request(
+    request: Request, model_type: Type[Model]
+) -> Model:
+    """Parse a JSON object from the body of a twisted HTTP request, then deserialise and
+    validate using the given pydantic model.
+
+    Raises:
+        SynapseError if the request body couldn't be decoded as JSON or
+            if it wasn't a JSON object.
+    """
+    content = parse_json_object_from_request(request, allow_empty_body=False)
+    try:
+        instance = model_type.parse_obj(content)
+    except ValidationError as e:
+        # Choose a matrix error code. The catch-all is BAD_JSON, but we try to find a
+        # more specific error if possible (which occasionally helps us to be spec-
+        # compliant) This is a bit awkward because the spec's error codes aren't very
+        # clear-cut: BAD_JSON arguably overlaps with MISSING_PARAM and INVALID_PARAM.
+        errcode = Codes.BAD_JSON
+
+        raw_errors = e.raw_errors
+        if len(raw_errors) == 1 and isinstance(raw_errors[0], ErrorWrapper):
+            raw_error = raw_errors[0].exc
+            if isinstance(raw_error, MissingError):
+                errcode = Codes.MISSING_PARAM
+            elif isinstance(raw_error, PydanticValueError):
+                errcode = Codes.INVALID_PARAM
+
+        raise SynapseError(HTTPStatus.BAD_REQUEST, str(e), errcode=errcode)
+
+    return instance
+
+
 def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None:
     absent = []
     for k in required:
diff --git a/synapse/http/site.py b/synapse/http/site.py
index eeec74b78a..1155f3f610 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -226,7 +226,7 @@ class SynapseRequest(Request):
 
             # If this is a request where the target user doesn't match the user who
             # authenticated (e.g. and admin is puppetting a user) then we return both.
-            if self._requester.user.to_string() != authenticated_entity:
+            if requester != authenticated_entity:
                 return requester, authenticated_entity
 
             return requester, None
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 50c57940f9..ca2735dd6d 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -84,14 +84,13 @@ the function becomes the operation name for the span.
        return something_usual_and_useful
 
 
-Operation names can be explicitly set for a function by passing the
-operation name to ``trace``
+Operation names can be explicitly set for a function by using ``trace_with_opname``:
 
 .. code-block:: python
 
-   from synapse.logging.opentracing import trace
+   from synapse.logging.opentracing import trace_with_opname
 
-   @trace(opname="a_better_operation_name")
+   @trace_with_opname("a_better_operation_name")
    def interesting_badly_named_function(*args, **kwargs):
        # Does all kinds of cool and expected things
        return something_usual_and_useful
@@ -174,6 +173,7 @@ from typing import (
     Any,
     Callable,
     Collection,
+    ContextManager,
     Dict,
     Generator,
     Iterable,
@@ -183,6 +183,8 @@ from typing import (
     Type,
     TypeVar,
     Union,
+    cast,
+    overload,
 )
 
 import attr
@@ -201,6 +203,9 @@ if TYPE_CHECKING:
 
 # Helper class
 
+# Matches the number suffix in an instance name like "matrix.org client_reader-8"
+STRIP_INSTANCE_NUMBER_SUFFIX_REGEX = re.compile(r"[_-]?\d+$")
+
 
 class _DummyTagNames:
     """wrapper of opentracings tags. We need to have them if we
@@ -293,6 +298,8 @@ class SynapseTags:
     # Whether the sync response has new data to be returned to the client.
     SYNC_RESULT = "sync.new_data"
 
+    INSTANCE_NAME = "instance_name"
+
     # incoming HTTP request ID  (as written in the logs)
     REQUEST_ID = "request_id"
 
@@ -308,6 +315,19 @@ class SynapseTags:
     # The name of the external cache
     CACHE_NAME = "cache.name"
 
+    # Used to tag function arguments
+    #
+    # Tag a named arg. The name of the argument should be appended to this prefix.
+    FUNC_ARG_PREFIX = "ARG."
+    # Tag extra variadic number of positional arguments (`def foo(first, second, *extras)`)
+    FUNC_ARGS = "args"
+    # Tag keyword args
+    FUNC_KWARGS = "kwargs"
+
+    # Some intermediate result that's interesting to the function. The label for
+    # the result should be appended to this prefix.
+    RESULT_PREFIX = "RESULT."
+
 
 class SynapseBaggage:
     FORCE_TRACING = "synapse-force-tracing"
@@ -329,6 +349,7 @@ class _Sentinel(enum.Enum):
 
 P = ParamSpec("P")
 R = TypeVar("R")
+T = TypeVar("T")
 
 
 def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]:
@@ -344,22 +365,43 @@ def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]:
     return _only_if_tracing_inner
 
 
-def ensure_active_span(message: str, ret=None):
+@overload
+def ensure_active_span(
+    message: str,
+) -> Callable[[Callable[P, R]], Callable[P, Optional[R]]]:
+    ...
+
+
+@overload
+def ensure_active_span(
+    message: str, ret: T
+) -> Callable[[Callable[P, R]], Callable[P, Union[T, R]]]:
+    ...
+
+
+def ensure_active_span(
+    message: str, ret: Optional[T] = None
+) -> Callable[[Callable[P, R]], Callable[P, Union[Optional[T], R]]]:
     """Executes the operation only if opentracing is enabled and there is an active span.
     If there is no active span it logs message at the error level.
 
     Args:
         message: Message which fills in "There was no active span when trying to %s"
             in the error log if there is no active span and opentracing is enabled.
-        ret (object): return value if opentracing is None or there is no active span.
+        ret: return value if opentracing is None or there is no active span.
 
-    Returns (object): The result of the func or ret if opentracing is disabled or there
+    Returns:
+        The result of the func, falling back to ret if opentracing is disabled or there
         was no active span.
     """
 
-    def ensure_active_span_inner_1(func):
+    def ensure_active_span_inner_1(
+        func: Callable[P, R]
+    ) -> Callable[P, Union[Optional[T], R]]:
         @wraps(func)
-        def ensure_active_span_inner_2(*args, **kwargs):
+        def ensure_active_span_inner_2(
+            *args: P.args, **kwargs: P.kwargs
+        ) -> Union[Optional[T], R]:
             if not opentracing:
                 return ret
 
@@ -404,9 +446,17 @@ def init_tracer(hs: "HomeServer") -> None:
 
     from jaeger_client.metrics.prometheus import PrometheusMetricsFactory
 
+    # Instance names are opaque strings but by stripping off the number suffix,
+    # we can get something that looks like a "worker type", e.g.
+    # "client_reader-1" -> "client_reader" so we don't spread the traces across
+    # so many services.
+    instance_name_by_type = re.sub(
+        STRIP_INSTANCE_NUMBER_SUFFIX_REGEX, "", hs.get_instance_name()
+    )
+
     config = JaegerConfig(
         config=hs.config.tracing.jaeger_config,
-        service_name=f"{hs.config.server.server_name} {hs.get_instance_name()}",
+        service_name=f"{hs.config.server.server_name} {instance_name_by_type}",
         scope_manager=LogContextScopeManager(),
         metrics_factory=PrometheusMetricsFactory(),
     )
@@ -465,7 +515,7 @@ def start_active_span(
     finish_on_close: bool = True,
     *,
     tracer: Optional["opentracing.Tracer"] = None,
-):
+) -> "opentracing.Scope":
     """Starts an active opentracing span.
 
     Records the start time for the span, and sets it as the "active span" in the
@@ -503,7 +553,7 @@ def start_active_span_follows_from(
     *,
     inherit_force_tracing: bool = False,
     tracer: Optional["opentracing.Tracer"] = None,
-):
+) -> "opentracing.Scope":
     """Starts an active opentracing span, with additional references to previous spans
 
     Args:
@@ -718,7 +768,9 @@ def inject_response_headers(response_headers: Headers) -> None:
         response_headers.addRawHeader("Synapse-Trace-Id", f"{trace_id:x}")
 
 
-@ensure_active_span("get the active span context as a dict", ret={})
+@ensure_active_span(
+    "get the active span context as a dict", ret=cast(Dict[str, str], {})
+)
 def get_active_span_text_map(destination: Optional[str] = None) -> Dict[str, str]:
     """
     Gets a span context as a dict. This can be used instead of manually
@@ -798,92 +850,155 @@ def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanConte
 # Tracing decorators
 
 
-def trace(func=None, opname: Optional[str] = None):
+def _custom_sync_async_decorator(
+    func: Callable[P, R],
+    wrapping_logic: Callable[[Callable[P, R], Any, Any], ContextManager[None]],
+) -> Callable[P, R]:
     """
-    Decorator to trace a function.
-    Sets the operation name to that of the function's or that given
-    as operation_name. See the module's doc string for usage
-    examples.
+    Decorates a function that is sync or async (coroutines), or that returns a Twisted
+    `Deferred`. The custom business logic of the decorator goes in `wrapping_logic`.
+
+    Example usage:
+    ```py
+    # Decorator to time the function and log it out
+    def duration(func: Callable[P, R]) -> Callable[P, R]:
+        @contextlib.contextmanager
+        def _wrapping_logic(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> Generator[None, None, None]:
+            start_ts = time.time()
+            try:
+                yield
+            finally:
+                end_ts = time.time()
+                duration = end_ts - start_ts
+                logger.info("%s took %s seconds", func.__name__, duration)
+        return _custom_sync_async_decorator(func, _wrapping_logic)
+    ```
+
+    Args:
+        func: The function to be decorated
+        wrapping_logic: The business logic of your custom decorator.
+            This should be a ContextManager so you are able to run your logic
+            before/after the function as desired.
     """
 
-    def decorator(func):
-        if opentracing is None:
-            return func  # type: ignore[unreachable]
+    if inspect.iscoroutinefunction(func):
 
-        _opname = opname if opname else func.__name__
+        @wraps(func)
+        async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+            with wrapping_logic(func, *args, **kwargs):
+                return await func(*args, **kwargs)  # type: ignore[misc]
 
-        if inspect.iscoroutinefunction(func):
+    else:
+        # The other case here handles both sync functions and those
+        # decorated with inlineDeferred.
+        @wraps(func)
+        def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+            scope = wrapping_logic(func, *args, **kwargs)
+            scope.__enter__()
 
-            @wraps(func)
-            async def _trace_inner(*args, **kwargs):
-                with start_active_span(_opname):
-                    return await func(*args, **kwargs)
+            try:
+                result = func(*args, **kwargs)
+                if isinstance(result, defer.Deferred):
 
-        else:
-            # The other case here handles both sync functions and those
-            # decorated with inlineDeferred.
-            @wraps(func)
-            def _trace_inner(*args, **kwargs):
-                scope = start_active_span(_opname)
-                scope.__enter__()
-
-                try:
-                    result = func(*args, **kwargs)
-                    if isinstance(result, defer.Deferred):
-
-                        def call_back(result: R) -> R:
-                            scope.__exit__(None, None, None)
-                            return result
-
-                        def err_back(result: R) -> R:
-                            scope.__exit__(None, None, None)
-                            return result
-
-                        result.addCallbacks(call_back, err_back)
-
-                    else:
-                        if inspect.isawaitable(result):
-                            logger.error(
-                                "@trace may not have wrapped %s correctly! "
-                                "The function is not async but returned a %s.",
-                                func.__qualname__,
-                                type(result).__name__,
-                            )
+                    def call_back(result: R) -> R:
+                        scope.__exit__(None, None, None)
+                        return result
 
+                    def err_back(result: R) -> R:
                         scope.__exit__(None, None, None)
+                        return result
 
-                    return result
+                    result.addCallbacks(call_back, err_back)
 
-                except Exception as e:
-                    scope.__exit__(type(e), None, e.__traceback__)
-                    raise
+                else:
+                    if inspect.isawaitable(result):
+                        logger.error(
+                            "@trace may not have wrapped %s correctly! "
+                            "The function is not async but returned a %s.",
+                            func.__qualname__,
+                            type(result).__name__,
+                        )
 
-        return _trace_inner
+                    scope.__exit__(None, None, None)
 
-    if func:
-        return decorator(func)
-    else:
-        return decorator
+                return result
+
+            except Exception as e:
+                scope.__exit__(type(e), None, e.__traceback__)
+                raise
+
+    return _wrapper  # type: ignore[return-value]
+
+
+def trace_with_opname(
+    opname: str,
+    *,
+    tracer: Optional["opentracing.Tracer"] = None,
+) -> Callable[[Callable[P, R]], Callable[P, R]]:
+    """
+    Decorator to trace a function with a custom opname.
+    See the module's doc string for usage examples.
+    """
+
+    # type-ignore: mypy bug, see https://github.com/python/mypy/issues/12909
+    @contextlib.contextmanager  # type: ignore[arg-type]
+    def _wrapping_logic(
+        func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
+    ) -> Generator[None, None, None]:
+        with start_active_span(opname, tracer=tracer):
+            yield
+
+    def _decorator(func: Callable[P, R]) -> Callable[P, R]:
+        if not opentracing:
+            return func
+
+        return _custom_sync_async_decorator(func, _wrapping_logic)
+
+    return _decorator
+
+
+def trace(func: Callable[P, R]) -> Callable[P, R]:
+    """
+    Decorator to trace a function.
+    Sets the operation name to that of the function's name.
+    See the module's doc string for usage examples.
+    """
+
+    return trace_with_opname(func.__name__)(func)
 
 
 def tag_args(func: Callable[P, R]) -> Callable[P, R]:
     """
-    Tags all of the args to the active span.
+    Decorator to tag all of the args to the active span.
+
+    Args:
+        func: `func` is assumed to be a method taking a `self` parameter, or a
+            `classmethod` taking a `cls` parameter. In either case, a tag is not
+            created for this parameter.
     """
 
     if not opentracing:
         return func
 
-    @wraps(func)
-    def _tag_args_inner(*args: P.args, **kwargs: P.kwargs) -> R:
+    # type-ignore: mypy bug, see https://github.com/python/mypy/issues/12909
+    @contextlib.contextmanager  # type: ignore[arg-type]
+    def _wrapping_logic(
+        func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
+    ) -> Generator[None, None, None]:
         argspec = inspect.getfullargspec(func)
-        for i, arg in enumerate(argspec.args[1:]):
-            set_tag("ARG_" + arg, args[i])  # type: ignore[index]
-        set_tag("args", args[len(argspec.args) :])  # type: ignore[index]
-        set_tag("kwargs", kwargs)
-        return func(*args, **kwargs)
+        # We use `[1:]` to skip the `self` object reference and `start=1` to
+        # make the index line up with `argspec.args`.
+        #
+        # FIXME: We could update this to handle any type of function by ignoring the
+        #   first argument only if it's named `self` or `cls`. This isn't fool-proof
+        #   but handles the idiomatic cases.
+        for i, arg in enumerate(args[1:], start=1):  # type: ignore[index]
+            set_tag(SynapseTags.FUNC_ARG_PREFIX + argspec.args[i], str(arg))
+        set_tag(SynapseTags.FUNC_ARGS, str(args[len(argspec.args) :]))  # type: ignore[index]
+        set_tag(SynapseTags.FUNC_KWARGS, str(kwargs))
+        yield
 
-    return _tag_args_inner
+    return _custom_sync_async_decorator(func, _wrapping_logic)
 
 
 @contextlib.contextmanager
@@ -930,11 +1045,11 @@ def trace_servlet(
             # with JsonResource).
             scope.span.set_operation_name(request.request_metrics.name)
 
-            # set the tags *after* the servlet completes, in case it decided to
-            # prioritise the span (tags will get dropped on unprioritised spans)
             request_tags[
                 SynapseTags.REQUEST_TAG
             ] = request.request_metrics.start_context.tag
 
+            # set the tags *after* the servlet completes, in case it decided to
+            # prioritise the span (tags will get dropped on unprioritised spans)
             for k, v in request_tags.items():
                 scope.span.set_tag(k, v)
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 496fce2ecc..c3d3daf877 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -46,12 +46,12 @@ from twisted.python.threadpool import ThreadPool
 
 # This module is imported for its side effects; flake8 needn't warn that it's unused.
 import synapse.metrics._reactor_metrics  # noqa: F401
-from synapse.metrics._exposition import (
+from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager
+from synapse.metrics._legacy_exposition import (
     MetricsResource,
     generate_latest,
     start_http_server,
 )
-from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager
 from synapse.metrics._types import Collector
 from synapse.util import SYNAPSE_VERSION
 
diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_legacy_exposition.py
index 353d0a63b6..563d8cc2c6 100644
--- a/synapse/metrics/_exposition.py
+++ b/synapse/metrics/_legacy_exposition.py
@@ -34,8 +34,6 @@ from prometheus_client.core import Sample
 from twisted.web.resource import Resource
 from twisted.web.server import Request
 
-from synapse.util import caches
-
 CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8"
 
 
@@ -80,12 +78,32 @@ def sample_line(line: Sample, name: str) -> str:
     return "{}{} {}{}\n".format(name, labelstr, floatToGoString(line.value), timestamp)
 
 
-def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> bytes:
+# Mapping from new metric names to legacy metric names.
+# We translate these back to their old names when exposing them through our
+# legacy vendored exporter.
+# Only this legacy exposition module applies these name changes.
+LEGACY_METRIC_NAMES = {
+    "synapse_util_caches_cache_hits": "synapse_util_caches_cache:hits",
+    "synapse_util_caches_cache_size": "synapse_util_caches_cache:size",
+    "synapse_util_caches_cache_evicted_size": "synapse_util_caches_cache:evicted_size",
+    "synapse_util_caches_cache": "synapse_util_caches_cache:total",
+    "synapse_util_caches_response_cache_size": "synapse_util_caches_response_cache:size",
+    "synapse_util_caches_response_cache_hits": "synapse_util_caches_response_cache:hits",
+    "synapse_util_caches_response_cache_evicted_size": "synapse_util_caches_response_cache:evicted_size",
+    "synapse_util_caches_response_cache": "synapse_util_caches_response_cache:total",
+    "synapse_federation_client_sent_pdu_destinations": "synapse_federation_client_sent_pdu_destinations:total",
+    "synapse_federation_client_sent_pdu_destinations_count": "synapse_federation_client_sent_pdu_destinations:count",
+    "synapse_admin_mau_current": "synapse_admin_mau:current",
+    "synapse_admin_mau_max": "synapse_admin_mau:max",
+    "synapse_admin_mau_registered_reserved_users": "synapse_admin_mau:registered_reserved_users",
+}
+
 
-    # Trigger the cache metrics to be rescraped, which updates the common
-    # metrics but do not produce metrics themselves
-    for collector in caches.collectors_by_name.values():
-        collector.collect()
+def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> bytes:
+    """
+    Generate metrics in legacy format. Modern metrics are generated directly
+    by prometheus-client.
+    """
 
     output = []
 
@@ -94,7 +112,8 @@ def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> byt
             # No samples, don't bother.
             continue
 
-        mname = metric.name
+        # Translate to legacy metric name if it has one.
+        mname = LEGACY_METRIC_NAMES.get(metric.name, metric.name)
         mnewname = metric.name
         mtype = metric.type
 
@@ -124,7 +143,7 @@ def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> byt
         om_samples: Dict[str, List[str]] = {}
         for s in metric.samples:
             for suffix in ["_created", "_gsum", "_gcount"]:
-                if s.name == metric.name + suffix:
+                if s.name == mname + suffix:
                     # OpenMetrics specific sample, put in a gauge at the end.
                     # (these come from gaugehistograms which don't get renamed,
                     # so no need to faff with mnewname)
@@ -140,12 +159,12 @@ def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> byt
             if emit_help:
                 output.append(
                     "# HELP {}{} {}\n".format(
-                        metric.name,
+                        mname,
                         suffix,
                         metric.documentation.replace("\\", r"\\").replace("\n", r"\n"),
                     )
                 )
-            output.append(f"# TYPE {metric.name}{suffix} gauge\n")
+            output.append(f"# TYPE {mname}{suffix} gauge\n")
             output.extend(lines)
 
         # Get rid of the weird colon things while we're at it
@@ -170,11 +189,12 @@ def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> byt
             # Get rid of the OpenMetrics specific samples (we should already have
             # dealt with them above anyway.)
             for suffix in ["_created", "_gsum", "_gcount"]:
-                if s.name == metric.name + suffix:
+                if s.name == mname + suffix:
                     break
             else:
+                sample_name = LEGACY_METRIC_NAMES.get(s.name, s.name)
                 output.append(
-                    sample_line(s, s.name.replace(":total", "").replace(":", "_"))
+                    sample_line(s, sample_name.replace(":total", "").replace(":", "_"))
                 )
 
     return "".join(output).encode("utf-8")
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index eef3462e10..7a1516d3a8 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -235,7 +235,7 @@ def run_as_background_process(
                         f"bgproc.{desc}", tags={SynapseTags.REQUEST_ID: str(context)}
                     )
                 else:
-                    ctx = nullcontext()
+                    ctx = nullcontext()  # type: ignore[assignment]
                 with ctx:
                     return await func(*args, **kwargs)
             except Exception:
diff --git a/synapse/metrics/common_usage_metrics.py b/synapse/metrics/common_usage_metrics.py
new file mode 100644
index 0000000000..0a22ea3d92
--- /dev/null
+++ b/synapse/metrics/common_usage_metrics.py
@@ -0,0 +1,79 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+import attr
+
+from synapse.metrics.background_process_metrics import run_as_background_process
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+from prometheus_client import Gauge
+
+# Gauge to expose daily active users metrics
+current_dau_gauge = Gauge(
+    "synapse_admin_daily_active_users",
+    "Current daily active users count",
+)
+
+
+@attr.s(auto_attribs=True)
+class CommonUsageMetrics:
+    """Usage metrics shared between the phone home stats and the prometheus exporter."""
+
+    daily_active_users: int
+
+
+class CommonUsageMetricsManager:
+    """Collects common usage metrics."""
+
+    def __init__(self, hs: "HomeServer") -> None:
+        self._store = hs.get_datastores().main
+        self._clock = hs.get_clock()
+
+    async def get_metrics(self) -> CommonUsageMetrics:
+        """Get the CommonUsageMetrics object. If no collection has happened yet, do it
+        before returning the metrics.
+
+        Returns:
+            The CommonUsageMetrics object to read common metrics from.
+        """
+        return await self._collect()
+
+    async def setup(self) -> None:
+        """Keep the gauges for common usage metrics up to date."""
+        await self._update_gauges()
+        self._clock.looping_call(
+            run_as_background_process,
+            5 * 60 * 1000,
+            desc="common_usage_metrics_update_gauges",
+            func=self._update_gauges,
+        )
+
+    async def _collect(self) -> CommonUsageMetrics:
+        """Collect the common metrics and either create the CommonUsageMetrics object to
+        use if it doesn't exist yet, or update it.
+        """
+        dau_count = await self._store.count_daily_users()
+
+        return CommonUsageMetrics(
+            daily_active_users=dau_count,
+        )
+
+    async def _update_gauges(self) -> None:
+        """Update the Prometheus gauges."""
+        metrics = await self._collect()
+
+        current_dau_gauge.set(float(metrics.daily_active_users))
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 6d8bf54083..87ba154cb7 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -929,10 +929,12 @@ class ModuleApi:
         room_id: str,
         new_membership: str,
         content: Optional[JsonDict] = None,
+        remote_room_hosts: Optional[List[str]] = None,
     ) -> EventBase:
         """Updates the membership of a user to the given value.
 
         Added in Synapse v1.46.0.
+        Changed in Synapse v1.65.0: Added the 'remote_room_hosts' parameter.
 
         Args:
             sender: The user performing the membership change. Must be a user local to
@@ -946,6 +948,7 @@ class ModuleApi:
                 https://spec.matrix.org/unstable/client-server-api/#mroommember for the
                 list of allowed values.
             content: Additional values to include in the resulting event's content.
+            remote_room_hosts: Remote servers to use for remote joins/knocks/etc.
 
         Returns:
             The newly created membership event.
@@ -1005,15 +1008,12 @@ class ModuleApi:
             room_id=room_id,
             action=new_membership,
             content=content,
+            remote_room_hosts=remote_room_hosts,
         )
 
         # Try to retrieve the resulting event.
         event = await self._hs.get_datastores().main.get_event(event_id)
 
-        # update_membership is supposed to always return after the event has been
-        # successfully persisted.
-        assert event is not None
-
         return event
 
     async def create_and_send_event_into_room(self, event_dict: JsonDict) -> EventBase:
@@ -1452,6 +1452,81 @@ class ModuleApi:
             start_timestamp, end_timestamp
         )
 
+    async def lookup_room_alias(self, room_alias: str) -> Tuple[str, List[str]]:
+        """
+        Get the room ID associated with a room alias.
+
+        Added in Synapse v1.65.0.
+
+        Args:
+            room_alias: The alias to look up.
+
+        Returns:
+            A tuple of:
+                The room ID (str).
+                Hosts likely to be participating in the room ([str]).
+
+        Raises:
+            SynapseError if room alias is invalid or could not be found.
+        """
+        alias = RoomAlias.from_string(room_alias)
+        (room_id, hosts) = await self._hs.get_room_member_handler().lookup_room_alias(
+            alias
+        )
+
+        return room_id.to_string(), hosts
+
+    async def create_room(
+        self,
+        user_id: str,
+        config: JsonDict,
+        ratelimit: bool = True,
+        creator_join_profile: Optional[JsonDict] = None,
+    ) -> Tuple[str, Optional[str]]:
+        """Creates a new room.
+
+        Added in Synapse v1.65.0.
+
+        Args:
+            user_id:
+                The user who requested the room creation.
+            config : A dict of configuration options. See "Request body" of:
+                https://spec.matrix.org/latest/client-server-api/#post_matrixclientv3createroom
+            ratelimit: set to False to disable the rate limiter for this specific operation.
+
+            creator_join_profile:
+                Set to override the displayname and avatar for the creating
+                user in this room. If unset, displayname and avatar will be
+                derived from the user's profile. If set, should contain the
+                values to go in the body of the 'join' event (typically
+                `avatar_url` and/or `displayname`.
+
+        Returns:
+                A tuple containing: 1) the room ID (str), 2) if an alias was requested,
+                the room alias (str), otherwise None if no alias was requested.
+
+        Raises:
+            ResourceLimitError if server is blocked to some resource being
+            exceeded.
+            RuntimeError if the user_id does not refer to a local user.
+            SynapseError if the user_id is invalid, room ID couldn't be stored, or
+            something went horribly wrong.
+        """
+        if not self.is_mine(user_id):
+            raise RuntimeError(
+                "Tried to create a room as a user that isn't local to this homeserver",
+            )
+
+        requester = create_requester(user_id)
+        room_id_and_alias, _ = await self._hs.get_room_creation_handler().create_room(
+            requester=requester,
+            config=config,
+            ratelimit=ratelimit,
+            creator_join_profile=creator_join_profile,
+        )
+
+        return room_id_and_alias["room_id"], room_id_and_alias.get("room_alias", None)
+
 
 class PublicRoomListManager:
     """Contains methods for adding to, removing from and querying whether a room
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 54b0ec4b97..c42bb8266a 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -228,6 +228,7 @@ class Notifier:
 
         # Called when there are new things to stream over replication
         self.replication_callbacks: List[Callable[[], None]] = []
+        self._new_join_in_room_callbacks: List[Callable[[str, str], None]] = []
 
         self._federation_client = hs.get_federation_http_client()
 
@@ -280,6 +281,19 @@ class Notifier:
         """
         self.replication_callbacks.append(cb)
 
+    def add_new_join_in_room_callback(self, cb: Callable[[str, str], None]) -> None:
+        """Add a callback that will be called when a user joins a room.
+
+        This only fires on genuine membership changes, e.g. "invite" -> "join".
+        Membership transitions like "join" -> "join" (for e.g. displayname changes) do
+        not trigger the callback.
+
+        When called, the callback receives two arguments: the event ID and the room ID.
+        It should *not* return a Deferred - if it needs to do any asynchronous work, a
+        background thread should be started and wrapped with run_as_background_process.
+        """
+        self._new_join_in_room_callbacks.append(cb)
+
     async def on_new_room_event(
         self,
         event: EventBase,
@@ -723,6 +737,10 @@ class Notifier:
         for cb in self.replication_callbacks:
             cb()
 
+    def notify_user_joined_room(self, event_id: str, room_id: str) -> None:
+        for cb in self._new_join_in_room_callbacks:
+            cb(event_id, room_id)
+
     def notify_remote_server_up(self, server: str) -> None:
         """Notify any replication that a remote server has come back up"""
         # We call federation_sender directly rather than registering as a
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index 6c0cc5a6ce..440205e80c 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -14,128 +14,235 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import copy
-from typing import Any, Dict, List
-
-from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
+"""
+Push rules is the system used to determine which events trigger a push (and a
+bump in notification counts).
+
+This consists of a list of "push rules" for each user, where a push rule is a
+pair of "conditions" and "actions". When a user receives an event Synapse
+iterates over the list of push rules until it finds one where all the conditions
+match the event, at which point "actions" describe the outcome (e.g. notify,
+highlight, etc).
+
+Push rules are split up into 5 different "kinds" (aka "priority classes"), which
+are run in order:
+    1. Override — highest priority rules, e.g. always ignore notices
+    2. Content — content specific rules, e.g. @ notifications
+    3. Room — per room rules, e.g. enable/disable notifications for all messages
+       in a room
+    4. Sender — per sender rules, e.g. never notify for messages from a given
+       user
+    5. Underride — the lowest priority "default" rules, e.g. notify for every
+       message.
+
+The set of "base rules" are the list of rules that every user has by default. A
+user can modify their copy of the push rules in one of three ways:
+
+    1. Adding a new push rule of a certain kind
+    2. Changing the actions of a base rule
+    3. Enabling/disabling a base rule.
+
+The base rules are split into whether they come before or after a particular
+kind, so the order of push rule evaluation would be: base rules for before
+"override" kind, user defined "override" rules, base rules after "override"
+kind, etc, etc.
+"""
+
+import itertools
+import logging
+from typing import Dict, Iterator, List, Mapping, Sequence, Tuple, Union
+
+import attr
+
+from synapse.config.experimental import ExperimentalConfig
+from synapse.push.rulekinds import PRIORITY_CLASS_MAP
+
+logger = logging.getLogger(__name__)
+
+
+@attr.s(auto_attribs=True, slots=True, frozen=True)
+class PushRule:
+    """A push rule
+
+    Attributes:
+        rule_id: a unique ID for this rule
+        priority_class: what "kind" of push rule this is (see
+            `PRIORITY_CLASS_MAP` for mapping between int and kind)
+        conditions: the sequence of conditions that all need to match
+        actions: the actions to apply if all conditions are met
+        default: is this a base rule?
+        default_enabled: is this enabled by default?
+    """
 
+    rule_id: str
+    priority_class: int
+    conditions: Sequence[Mapping[str, str]]
+    actions: Sequence[Union[str, Mapping]]
+    default: bool = False
+    default_enabled: bool = True
 
-def list_with_base_rules(rawrules: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
-    """Combine the list of rules set by the user with the default push rules
 
-    Args:
-        rawrules: The rules the user has modified or set.
+@attr.s(auto_attribs=True, slots=True, frozen=True, weakref_slot=False)
+class PushRules:
+    """A collection of push rules for an account.
 
-    Returns:
-        A new list with the rules set by the user combined with the defaults.
+    Can be iterated over, producing push rules in priority order.
     """
-    ruleslist = []
 
-    # Grab the base rules that the user has modified.
-    # The modified base rules have a priority_class of -1.
-    modified_base_rules = {r["rule_id"]: r for r in rawrules if r["priority_class"] < 0}
+    # A mapping from rule ID to push rule that overrides a base rule. These will
+    # be returned instead of the base rule.
+    overriden_base_rules: Dict[str, PushRule] = attr.Factory(dict)
+
+    # The following stores the custom push rules at each priority class.
+    #
+    # We keep these separate (rather than combining into one big list) to avoid
+    # copying the base rules around all the time.
+    override: List[PushRule] = attr.Factory(list)
+    content: List[PushRule] = attr.Factory(list)
+    room: List[PushRule] = attr.Factory(list)
+    sender: List[PushRule] = attr.Factory(list)
+    underride: List[PushRule] = attr.Factory(list)
+
+    def __iter__(self) -> Iterator[PushRule]:
+        # When iterating over the push rules we need to return the base rules
+        # interspersed at the correct spots.
+        for rule in itertools.chain(
+            BASE_PREPEND_OVERRIDE_RULES,
+            self.override,
+            BASE_APPEND_OVERRIDE_RULES,
+            self.content,
+            BASE_APPEND_CONTENT_RULES,
+            self.room,
+            self.sender,
+            self.underride,
+            BASE_APPEND_UNDERRIDE_RULES,
+        ):
+            # Check if a base rule has been overriden by a custom rule. If so
+            # return that instead.
+            override_rule = self.overriden_base_rules.get(rule.rule_id)
+            if override_rule:
+                yield override_rule
+            else:
+                yield rule
+
+    def __len__(self) -> int:
+        # The length is mostly used by caches to get a sense of "size" / amount
+        # of memory this object is using, so we only count the number of custom
+        # rules.
+        return (
+            len(self.overriden_base_rules)
+            + len(self.override)
+            + len(self.content)
+            + len(self.room)
+            + len(self.sender)
+            + len(self.underride)
+        )
 
-    # Remove the modified base rules from the list, They'll be added back
-    # in the default positions in the list.
-    rawrules = [r for r in rawrules if r["priority_class"] >= 0]
 
-    # shove the server default rules for each kind onto the end of each
-    current_prio_class = list(PRIORITY_CLASS_INVERSE_MAP)[-1]
+@attr.s(auto_attribs=True, slots=True, frozen=True, weakref_slot=False)
+class FilteredPushRules:
+    """A wrapper around `PushRules` that filters out disabled experimental push
+    rules, and includes the "enabled" state for each rule when iterated over.
+    """
 
-    ruleslist.extend(
-        make_base_prepend_rules(
-            PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
-        )
-    )
+    push_rules: PushRules
+    enabled_map: Dict[str, bool]
+    experimental_config: ExperimentalConfig
 
-    for r in rawrules:
-        if r["priority_class"] < current_prio_class:
-            while r["priority_class"] < current_prio_class:
-                ruleslist.extend(
-                    make_base_append_rules(
-                        PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
-                        modified_base_rules,
-                    )
-                )
-                current_prio_class -= 1
-                if current_prio_class > 0:
-                    ruleslist.extend(
-                        make_base_prepend_rules(
-                            PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
-                            modified_base_rules,
-                        )
-                    )
-
-        ruleslist.append(r)
-
-    while current_prio_class > 0:
-        ruleslist.extend(
-            make_base_append_rules(
-                PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
-            )
-        )
-        current_prio_class -= 1
-        if current_prio_class > 0:
-            ruleslist.extend(
-                make_base_prepend_rules(
-                    PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
-                )
-            )
+    def __iter__(self) -> Iterator[Tuple[PushRule, bool]]:
+        for rule in self.push_rules:
+            if not _is_experimental_rule_enabled(
+                rule.rule_id, self.experimental_config
+            ):
+                continue
 
-    return ruleslist
+            enabled = self.enabled_map.get(rule.rule_id, rule.default_enabled)
 
+            yield rule, enabled
 
-def make_base_append_rules(
-    kind: str, modified_base_rules: Dict[str, Dict[str, Any]]
-) -> List[Dict[str, Any]]:
-    rules = []
+    def __len__(self) -> int:
+        return len(self.push_rules)
 
-    if kind == "override":
-        rules = BASE_APPEND_OVERRIDE_RULES
-    elif kind == "underride":
-        rules = BASE_APPEND_UNDERRIDE_RULES
-    elif kind == "content":
-        rules = BASE_APPEND_CONTENT_RULES
 
-    # Copy the rules before modifying them
-    rules = copy.deepcopy(rules)
-    for r in rules:
-        # Only modify the actions, keep the conditions the same.
-        assert isinstance(r["rule_id"], str)
-        modified = modified_base_rules.get(r["rule_id"])
-        if modified:
-            r["actions"] = modified["actions"]
+DEFAULT_EMPTY_PUSH_RULES = PushRules()
 
-    return rules
 
+def compile_push_rules(rawrules: List[PushRule]) -> PushRules:
+    """Given a set of custom push rules return a `PushRules` instance (which
+    includes the base rules).
+    """
+
+    if not rawrules:
+        # Fast path to avoid allocating empty lists when there are no custom
+        # rules for the user.
+        return DEFAULT_EMPTY_PUSH_RULES
+
+    rules = PushRules()
 
-def make_base_prepend_rules(
-    kind: str,
-    modified_base_rules: Dict[str, Dict[str, Any]],
-) -> List[Dict[str, Any]]:
-    rules = []
+    for rule in rawrules:
+        # We need to decide which bucket each custom push rule goes into.
 
-    if kind == "override":
-        rules = BASE_PREPEND_OVERRIDE_RULES
+        # If it has the same ID as a base rule then it overrides that...
+        overriden_base_rule = BASE_RULES_BY_ID.get(rule.rule_id)
+        if overriden_base_rule:
+            rules.overriden_base_rules[rule.rule_id] = attr.evolve(
+                overriden_base_rule, actions=rule.actions
+            )
+            continue
+
+        # ... otherwise it gets added to the appropriate priority class bucket
+        collection: List[PushRule]
+        if rule.priority_class == 5:
+            collection = rules.override
+        elif rule.priority_class == 4:
+            collection = rules.content
+        elif rule.priority_class == 3:
+            collection = rules.room
+        elif rule.priority_class == 2:
+            collection = rules.sender
+        elif rule.priority_class == 1:
+            collection = rules.underride
+        elif rule.priority_class <= 0:
+            logger.info(
+                "Got rule with priority class less than zero, but doesn't override a base rule: %s",
+                rule,
+            )
+            continue
+        else:
+            # We log and continue here so as not to break event sending
+            logger.error("Unknown priority class: %", rule.priority_class)
+            continue
 
-    # Copy the rules before modifying them
-    rules = copy.deepcopy(rules)
-    for r in rules:
-        # Only modify the actions, keep the conditions the same.
-        assert isinstance(r["rule_id"], str)
-        modified = modified_base_rules.get(r["rule_id"])
-        if modified:
-            r["actions"] = modified["actions"]
+        collection.append(rule)
 
     return rules
 
 
-# We have to annotate these types, otherwise mypy infers them as
-# `List[Dict[str, Sequence[Collection[str]]]]`.
-BASE_APPEND_CONTENT_RULES: List[Dict[str, Any]] = [
-    {
-        "rule_id": "global/content/.m.rule.contains_user_name",
-        "conditions": [
+def _is_experimental_rule_enabled(
+    rule_id: str, experimental_config: ExperimentalConfig
+) -> bool:
+    """Used by `FilteredPushRules` to filter out experimental rules when they
+    have not been enabled.
+    """
+    if (
+        rule_id == "global/override/.org.matrix.msc3786.rule.room.server_acl"
+        and not experimental_config.msc3786_enabled
+    ):
+        return False
+    if (
+        rule_id == "global/underride/.org.matrix.msc3772.thread_reply"
+        and not experimental_config.msc3772_enabled
+    ):
+        return False
+    return True
+
+
+BASE_APPEND_CONTENT_RULES = [
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["content"],
+        rule_id="global/content/.m.rule.contains_user_name",
+        conditions=[
             {
                 "kind": "event_match",
                 "key": "content.body",
@@ -143,29 +250,33 @@ BASE_APPEND_CONTENT_RULES: List[Dict[str, Any]] = [
                 "pattern_type": "user_localpart",
             }
         ],
-        "actions": [
+        actions=[
             "notify",
             {"set_tweak": "sound", "value": "default"},
             {"set_tweak": "highlight"},
         ],
-    }
+    )
 ]
 
 
-BASE_PREPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
-    {
-        "rule_id": "global/override/.m.rule.master",
-        "enabled": False,
-        "conditions": [],
-        "actions": ["dont_notify"],
-    }
+BASE_PREPEND_OVERRIDE_RULES = [
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["override"],
+        rule_id="global/override/.m.rule.master",
+        default_enabled=False,
+        conditions=[],
+        actions=["dont_notify"],
+    )
 ]
 
 
-BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
-    {
-        "rule_id": "global/override/.m.rule.suppress_notices",
-        "conditions": [
+BASE_APPEND_OVERRIDE_RULES = [
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["override"],
+        rule_id="global/override/.m.rule.suppress_notices",
+        conditions=[
             {
                 "kind": "event_match",
                 "key": "content.msgtype",
@@ -173,13 +284,15 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
                 "_cache_key": "_suppress_notices",
             }
         ],
-        "actions": ["dont_notify"],
-    },
+        actions=["dont_notify"],
+    ),
     # NB. .m.rule.invite_for_me must be higher prio than .m.rule.member_event
     # otherwise invites will be matched by .m.rule.member_event
-    {
-        "rule_id": "global/override/.m.rule.invite_for_me",
-        "conditions": [
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["override"],
+        rule_id="global/override/.m.rule.invite_for_me",
+        conditions=[
             {
                 "kind": "event_match",
                 "key": "type",
@@ -195,21 +308,23 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
             # Match the requester's MXID.
             {"kind": "event_match", "key": "state_key", "pattern_type": "user_id"},
         ],
-        "actions": [
+        actions=[
             "notify",
             {"set_tweak": "sound", "value": "default"},
             {"set_tweak": "highlight", "value": False},
         ],
-    },
+    ),
     # Will we sometimes want to know about people joining and leaving?
     # Perhaps: if so, this could be expanded upon. Seems the most usual case
     # is that we don't though. We add this override rule so that even if
     # the room rule is set to notify, we don't get notifications about
     # join/leave/avatar/displayname events.
     # See also: https://matrix.org/jira/browse/SYN-607
-    {
-        "rule_id": "global/override/.m.rule.member_event",
-        "conditions": [
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["override"],
+        rule_id="global/override/.m.rule.member_event",
+        conditions=[
             {
                 "kind": "event_match",
                 "key": "type",
@@ -217,24 +332,28 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
                 "_cache_key": "_member",
             }
         ],
-        "actions": ["dont_notify"],
-    },
+        actions=["dont_notify"],
+    ),
     # This was changed from underride to override so it's closer in priority
     # to the content rules where the user name highlight rule lives. This
     # way a room rule is lower priority than both but a custom override rule
     # is higher priority than both.
-    {
-        "rule_id": "global/override/.m.rule.contains_display_name",
-        "conditions": [{"kind": "contains_display_name"}],
-        "actions": [
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["override"],
+        rule_id="global/override/.m.rule.contains_display_name",
+        conditions=[{"kind": "contains_display_name"}],
+        actions=[
             "notify",
             {"set_tweak": "sound", "value": "default"},
             {"set_tweak": "highlight"},
         ],
-    },
-    {
-        "rule_id": "global/override/.m.rule.roomnotif",
-        "conditions": [
+    ),
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["override"],
+        rule_id="global/override/.m.rule.roomnotif",
+        conditions=[
             {
                 "kind": "event_match",
                 "key": "content.body",
@@ -247,11 +366,13 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
                 "_cache_key": "_roomnotif_pl",
             },
         ],
-        "actions": ["notify", {"set_tweak": "highlight", "value": True}],
-    },
-    {
-        "rule_id": "global/override/.m.rule.tombstone",
-        "conditions": [
+        actions=["notify", {"set_tweak": "highlight", "value": True}],
+    ),
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["override"],
+        rule_id="global/override/.m.rule.tombstone",
+        conditions=[
             {
                 "kind": "event_match",
                 "key": "type",
@@ -265,11 +386,13 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
                 "_cache_key": "_tombstone_statekey",
             },
         ],
-        "actions": ["notify", {"set_tweak": "highlight", "value": True}],
-    },
-    {
-        "rule_id": "global/override/.m.rule.reaction",
-        "conditions": [
+        actions=["notify", {"set_tweak": "highlight", "value": True}],
+    ),
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["override"],
+        rule_id="global/override/.m.rule.reaction",
+        conditions=[
             {
                 "kind": "event_match",
                 "key": "type",
@@ -277,14 +400,16 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
                 "_cache_key": "_reaction",
             }
         ],
-        "actions": ["dont_notify"],
-    },
+        actions=["dont_notify"],
+    ),
     # XXX: This is an experimental rule that is only enabled if msc3786_enabled
     # is enabled, if it is not the rule gets filtered out in _load_rules() in
     # PushRulesWorkerStore
-    {
-        "rule_id": "global/override/.org.matrix.msc3786.rule.room.server_acl",
-        "conditions": [
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["override"],
+        rule_id="global/override/.org.matrix.msc3786.rule.room.server_acl",
+        conditions=[
             {
                 "kind": "event_match",
                 "key": "type",
@@ -298,15 +423,17 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
                 "_cache_key": "_room_server_acl_state_key",
             },
         ],
-        "actions": [],
-    },
+        actions=[],
+    ),
 ]
 
 
-BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
-    {
-        "rule_id": "global/underride/.m.rule.call",
-        "conditions": [
+BASE_APPEND_UNDERRIDE_RULES = [
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["underride"],
+        rule_id="global/underride/.m.rule.call",
+        conditions=[
             {
                 "kind": "event_match",
                 "key": "type",
@@ -314,17 +441,19 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
                 "_cache_key": "_call",
             }
         ],
-        "actions": [
+        actions=[
             "notify",
             {"set_tweak": "sound", "value": "ring"},
             {"set_tweak": "highlight", "value": False},
         ],
-    },
+    ),
     # XXX: once m.direct is standardised everywhere, we should use it to detect
     # a DM from the user's perspective rather than this heuristic.
-    {
-        "rule_id": "global/underride/.m.rule.room_one_to_one",
-        "conditions": [
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["underride"],
+        rule_id="global/underride/.m.rule.room_one_to_one",
+        conditions=[
             {"kind": "room_member_count", "is": "2", "_cache_key": "member_count"},
             {
                 "kind": "event_match",
@@ -333,17 +462,19 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
                 "_cache_key": "_message",
             },
         ],
-        "actions": [
+        actions=[
             "notify",
             {"set_tweak": "sound", "value": "default"},
             {"set_tweak": "highlight", "value": False},
         ],
-    },
+    ),
     # XXX: this is going to fire for events which aren't m.room.messages
     # but are encrypted (e.g. m.call.*)...
-    {
-        "rule_id": "global/underride/.m.rule.encrypted_room_one_to_one",
-        "conditions": [
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["underride"],
+        rule_id="global/underride/.m.rule.encrypted_room_one_to_one",
+        conditions=[
             {"kind": "room_member_count", "is": "2", "_cache_key": "member_count"},
             {
                 "kind": "event_match",
@@ -352,15 +483,17 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
                 "_cache_key": "_encrypted",
             },
         ],
-        "actions": [
+        actions=[
             "notify",
             {"set_tweak": "sound", "value": "default"},
             {"set_tweak": "highlight", "value": False},
         ],
-    },
-    {
-        "rule_id": "global/underride/.org.matrix.msc3772.thread_reply",
-        "conditions": [
+    ),
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["underride"],
+        rule_id="global/underride/.org.matrix.msc3772.thread_reply",
+        conditions=[
             {
                 "kind": "org.matrix.msc3772.relation_match",
                 "rel_type": "m.thread",
@@ -368,11 +501,13 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
                 "sender_type": "user_id",
             }
         ],
-        "actions": ["notify", {"set_tweak": "highlight", "value": False}],
-    },
-    {
-        "rule_id": "global/underride/.m.rule.message",
-        "conditions": [
+        actions=["notify", {"set_tweak": "highlight", "value": False}],
+    ),
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["underride"],
+        rule_id="global/underride/.m.rule.message",
+        conditions=[
             {
                 "kind": "event_match",
                 "key": "type",
@@ -380,13 +515,15 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
                 "_cache_key": "_message",
             }
         ],
-        "actions": ["notify", {"set_tweak": "highlight", "value": False}],
-    },
+        actions=["notify", {"set_tweak": "highlight", "value": False}],
+    ),
     # XXX: this is going to fire for events which aren't m.room.messages
     # but are encrypted (e.g. m.call.*)...
-    {
-        "rule_id": "global/underride/.m.rule.encrypted",
-        "conditions": [
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["underride"],
+        rule_id="global/underride/.m.rule.encrypted",
+        conditions=[
             {
                 "kind": "event_match",
                 "key": "type",
@@ -394,11 +531,13 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
                 "_cache_key": "_encrypted",
             }
         ],
-        "actions": ["notify", {"set_tweak": "highlight", "value": False}],
-    },
-    {
-        "rule_id": "global/underride/.im.vector.jitsi",
-        "conditions": [
+        actions=["notify", {"set_tweak": "highlight", "value": False}],
+    ),
+    PushRule(
+        default=True,
+        priority_class=PRIORITY_CLASS_MAP["underride"],
+        rule_id="global/underride/.im.vector.jitsi",
+        conditions=[
             {
                 "kind": "event_match",
                 "key": "type",
@@ -418,29 +557,27 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
                 "_cache_key": "_is_state_event",
             },
         ],
-        "actions": ["notify", {"set_tweak": "highlight", "value": False}],
-    },
+        actions=["notify", {"set_tweak": "highlight", "value": False}],
+    ),
 ]
 
 
 BASE_RULE_IDS = set()
 
+BASE_RULES_BY_ID: Dict[str, PushRule] = {}
+
 for r in BASE_APPEND_CONTENT_RULES:
-    r["priority_class"] = PRIORITY_CLASS_MAP["content"]
-    r["default"] = True
-    BASE_RULE_IDS.add(r["rule_id"])
+    BASE_RULE_IDS.add(r.rule_id)
+    BASE_RULES_BY_ID[r.rule_id] = r
 
 for r in BASE_PREPEND_OVERRIDE_RULES:
-    r["priority_class"] = PRIORITY_CLASS_MAP["override"]
-    r["default"] = True
-    BASE_RULE_IDS.add(r["rule_id"])
+    BASE_RULE_IDS.add(r.rule_id)
+    BASE_RULES_BY_ID[r.rule_id] = r
 
 for r in BASE_APPEND_OVERRIDE_RULES:
-    r["priority_class"] = PRIORITY_CLASS_MAP["override"]
-    r["default"] = True
-    BASE_RULE_IDS.add(r["rule_id"])
+    BASE_RULE_IDS.add(r.rule_id)
+    BASE_RULES_BY_ID[r.rule_id] = r
 
 for r in BASE_APPEND_UNDERRIDE_RULES:
-    r["priority_class"] = PRIORITY_CLASS_MAP["underride"]
-    r["default"] = True
-    BASE_RULE_IDS.add(r["rule_id"])
+    BASE_RULE_IDS.add(r.rule_id)
+    BASE_RULES_BY_ID[r.rule_id] = r
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index e581af9a9a..3846fbc5f0 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -15,7 +15,18 @@
 
 import itertools
 import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+)
 
 from prometheus_client import Counter
 
@@ -30,6 +41,7 @@ from synapse.util.caches import register_cache
 from synapse.util.metrics import measure_func
 from synapse.visibility import filter_event_for_clients_with_state
 
+from .baserules import FilteredPushRules, PushRule
 from .push_rule_evaluator import PushRuleEvaluatorForEvent
 
 if TYPE_CHECKING:
@@ -112,7 +124,7 @@ class BulkPushRuleEvaluator:
     async def _get_rules_for_event(
         self,
         event: EventBase,
-    ) -> Dict[str, List[Dict[str, Any]]]:
+    ) -> Dict[str, FilteredPushRules]:
         """Get the push rules for all users who may need to be notified about
         the event.
 
@@ -131,6 +143,13 @@ class BulkPushRuleEvaluator:
 
         local_users = await self.store.get_local_users_in_room(event.room_id)
 
+        # Filter out appservice users.
+        local_users = [
+            u
+            for u in local_users
+            if not self.store.get_if_app_services_interested_in_user(u)
+        ]
+
         # if this event is an invite event, we may need to run rules for the user
         # who's been invited, otherwise they won't get told they've been invited
         if event.type == EventTypes.Member and event.membership == Membership.INVITE:
@@ -179,7 +198,7 @@ class BulkPushRuleEvaluator:
         return pl_event.content if pl_event else {}, sender_level
 
     async def _get_mutual_relations(
-        self, event: EventBase, rules: Iterable[Dict[str, Any]]
+        self, parent_id: str, rules: Iterable[Tuple[PushRule, bool]]
     ) -> Dict[str, Set[Tuple[str, str]]]:
         """
         Fetch event metadata for events which related to the same event as the given event.
@@ -187,7 +206,7 @@ class BulkPushRuleEvaluator:
         If the given event has no relation information, returns an empty dictionary.
 
         Args:
-            event_id: The event ID which is targeted by relations.
+            parent_id: The event ID which is targeted by relations.
             rules: The push rules which will be processed for this event.
 
         Returns:
@@ -201,20 +220,13 @@ class BulkPushRuleEvaluator:
         if not self._relations_match_enabled:
             return {}
 
-        # If the event does not have a relation, then cannot have any mutual
-        # relations.
-        relation = relation_from_event(event)
-        if not relation:
-            return {}
-
         # Pre-filter to figure out which relation types are interesting.
         rel_types = set()
-        for rule in rules:
-            # Skip disabled rules.
-            if "enabled" in rule and not rule["enabled"]:
+        for rule, enabled in rules:
+            if not enabled:
                 continue
 
-            for condition in rule["conditions"]:
+            for condition in rule.conditions:
                 if condition["kind"] != "org.matrix.msc3772.relation_match":
                     continue
 
@@ -228,9 +240,7 @@ class BulkPushRuleEvaluator:
             return {}
 
         # If any valid rules were found, fetch the mutual relations.
-        return await self.store.get_mutual_event_relations(
-            relation.parent_id, rel_types
-        )
+        return await self.store.get_mutual_event_relations(parent_id, rel_types)
 
     @measure_func("action_for_event_by_user")
     async def action_for_event_by_user(
@@ -244,10 +254,15 @@ class BulkPushRuleEvaluator:
             # This can happen due to out of band memberships
             return
 
-        count_as_unread = _should_count_as_unread(event, context)
+        # Disable counting as unread unless the experimental configuration is
+        # enabled, as it can cause additional (unwanted) rows to be added to the
+        # event_push_actions table.
+        count_as_unread = False
+        if self.hs.config.experimental.msc2654_enabled:
+            count_as_unread = _should_count_as_unread(event, context)
 
         rules_by_user = await self._get_rules_for_event(event)
-        actions_by_user: Dict[str, List[Union[dict, str]]] = {}
+        actions_by_user: Dict[str, Collection[Union[Mapping, str]]] = {}
 
         room_member_count = await self.store.get_number_joined_users_in_room(
             event.room_id
@@ -258,9 +273,17 @@ class BulkPushRuleEvaluator:
             sender_power_level,
         ) = await self._get_power_levels_and_sender_level(event, context)
 
-        relations = await self._get_mutual_relations(
-            event, itertools.chain(*rules_by_user.values())
-        )
+        relation = relation_from_event(event)
+        # If the event does not have a relation, then cannot have any mutual
+        # relations or thread ID.
+        relations = {}
+        thread_id = "main"
+        if relation:
+            relations = await self._get_mutual_relations(
+                relation.parent_id, itertools.chain(*rules_by_user.values())
+            )
+            if relation.rel_type == RelationTypes.THREAD:
+                thread_id = relation.parent_id
 
         evaluator = PushRuleEvaluatorForEvent(
             event,
@@ -310,15 +333,13 @@ class BulkPushRuleEvaluator:
                 # current user, it'll be added to the dict later.
                 actions_by_user[uid] = []
 
-            for rule in rules:
-                if "enabled" in rule and not rule["enabled"]:
+            for rule, enabled in rules:
+                if not enabled:
                     continue
 
-                matches = evaluator.check_conditions(
-                    rule["conditions"], uid, display_name
-                )
+                matches = evaluator.check_conditions(rule.conditions, uid, display_name)
                 if matches:
-                    actions = [x for x in rule["actions"] if x != "dont_notify"]
+                    actions = [x for x in rule.actions if x != "dont_notify"]
                     if actions and "notify" in actions:
                         # Push rules say we should notify the user of this event
                         actions_by_user[uid] = actions
@@ -331,6 +352,7 @@ class BulkPushRuleEvaluator:
             event.event_id,
             actions_by_user,
             count_as_unread,
+            thread_id,
         )
 
 
diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py
index 5117ef6854..73618d9234 100644
--- a/synapse/push/clientformat.py
+++ b/synapse/push/clientformat.py
@@ -18,16 +18,15 @@ from typing import Any, Dict, List, Optional
 from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
 from synapse.types import UserID
 
+from .baserules import FilteredPushRules, PushRule
+
 
 def format_push_rules_for_user(
-    user: UserID, ruleslist: List
+    user: UserID, ruleslist: FilteredPushRules
 ) -> Dict[str, Dict[str, list]]:
     """Converts a list of rawrules and a enabled map into nested dictionaries
     to match the Matrix client-server format for push rules"""
 
-    # We're going to be mutating this a lot, so do a deep copy
-    ruleslist = copy.deepcopy(ruleslist)
-
     rules: Dict[str, Dict[str, List[Dict[str, Any]]]] = {
         "global": {},
         "device": {},
@@ -35,11 +34,30 @@ def format_push_rules_for_user(
 
     rules["global"] = _add_empty_priority_class_arrays(rules["global"])
 
-    for r in ruleslist:
-        template_name = _priority_class_to_template_name(r["priority_class"])
+    for r, enabled in ruleslist:
+        template_name = _priority_class_to_template_name(r.priority_class)
+
+        rulearray = rules["global"][template_name]
+
+        template_rule = _rule_to_template(r)
+        if not template_rule:
+            continue
+
+        rulearray.append(template_rule)
+
+        template_rule["enabled"] = enabled
+
+        if "conditions" not in template_rule:
+            # Not all formatted rules have explicit conditions, e.g. "room"
+            # rules omit them as they can be derived from the kind and rule ID.
+            #
+            # If the formatted rule has no conditions then we can skip the
+            # formatting of conditions.
+            continue
 
         # Remove internal stuff.
-        for c in r["conditions"]:
+        template_rule["conditions"] = copy.deepcopy(template_rule["conditions"])
+        for c in template_rule["conditions"]:
             c.pop("_cache_key", None)
 
             pattern_type = c.pop("pattern_type", None)
@@ -52,16 +70,6 @@ def format_push_rules_for_user(
             if sender_type == "user_id":
                 c["sender"] = user.to_string()
 
-        rulearray = rules["global"][template_name]
-
-        template_rule = _rule_to_template(r)
-        if template_rule:
-            if "enabled" in r:
-                template_rule["enabled"] = r["enabled"]
-            else:
-                template_rule["enabled"] = True
-            rulearray.append(template_rule)
-
     return rules
 
 
@@ -71,24 +79,24 @@ def _add_empty_priority_class_arrays(d: Dict[str, list]) -> Dict[str, list]:
     return d
 
 
-def _rule_to_template(rule: Dict[str, Any]) -> Optional[Dict[str, Any]]:
-    unscoped_rule_id = None
-    if "rule_id" in rule:
-        unscoped_rule_id = _rule_id_from_namespaced(rule["rule_id"])
+def _rule_to_template(rule: PushRule) -> Optional[Dict[str, Any]]:
+    templaterule: Dict[str, Any]
+
+    unscoped_rule_id = _rule_id_from_namespaced(rule.rule_id)
 
-    template_name = _priority_class_to_template_name(rule["priority_class"])
+    template_name = _priority_class_to_template_name(rule.priority_class)
     if template_name in ["override", "underride"]:
-        templaterule = {k: rule[k] for k in ["conditions", "actions"]}
+        templaterule = {"conditions": rule.conditions, "actions": rule.actions}
     elif template_name in ["sender", "room"]:
-        templaterule = {"actions": rule["actions"]}
-        unscoped_rule_id = rule["conditions"][0]["pattern"]
+        templaterule = {"actions": rule.actions}
+        unscoped_rule_id = rule.conditions[0]["pattern"]
     elif template_name == "content":
-        if len(rule["conditions"]) != 1:
+        if len(rule.conditions) != 1:
             return None
-        thecond = rule["conditions"][0]
+        thecond = rule.conditions[0]
         if "pattern" not in thecond:
             return None
-        templaterule = {"actions": rule["actions"]}
+        templaterule = {"actions": rule.actions}
         templaterule["pattern"] = thecond["pattern"]
     else:
         # This should not be reached unless this function is not kept in sync
@@ -97,8 +105,8 @@ def _rule_to_template(rule: Dict[str, Any]) -> Optional[Dict[str, Any]]:
 
     if unscoped_rule_id:
         templaterule["rule_id"] = unscoped_rule_id
-    if "default" in rule:
-        templaterule["default"] = rule["default"]
+    if rule.default:
+        templaterule["default"] = True
     return templaterule
 
 
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 2e8a017add..3c5632cd91 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -15,7 +15,18 @@
 
 import logging
 import re
-from typing import Any, Dict, List, Mapping, Optional, Pattern, Set, Tuple, Union
+from typing import (
+    Any,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Pattern,
+    Sequence,
+    Set,
+    Tuple,
+    Union,
+)
 
 from matrix_common.regex import glob_to_regex, to_word_pattern
 
@@ -32,14 +43,14 @@ INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
 
 
 def _room_member_count(
-    ev: EventBase, condition: Dict[str, Any], room_member_count: int
+    ev: EventBase, condition: Mapping[str, Any], room_member_count: int
 ) -> bool:
     return _test_ineq_condition(condition, room_member_count)
 
 
 def _sender_notification_permission(
     ev: EventBase,
-    condition: Dict[str, Any],
+    condition: Mapping[str, Any],
     sender_power_level: int,
     power_levels: Dict[str, Union[int, Dict[str, int]]],
 ) -> bool:
@@ -54,7 +65,7 @@ def _sender_notification_permission(
     return sender_power_level >= room_notif_level
 
 
-def _test_ineq_condition(condition: Dict[str, Any], number: int) -> bool:
+def _test_ineq_condition(condition: Mapping[str, Any], number: int) -> bool:
     if "is" not in condition:
         return False
     m = INEQUALITY_EXPR.match(condition["is"])
@@ -137,7 +148,7 @@ class PushRuleEvaluatorForEvent:
         self._condition_cache: Dict[str, bool] = {}
 
     def check_conditions(
-        self, conditions: List[dict], uid: str, display_name: Optional[str]
+        self, conditions: Sequence[Mapping], uid: str, display_name: Optional[str]
     ) -> bool:
         """
         Returns true if a user's conditions/user ID/display name match the event.
@@ -169,7 +180,7 @@ class PushRuleEvaluatorForEvent:
         return True
 
     def matches(
-        self, condition: Dict[str, Any], user_id: str, display_name: Optional[str]
+        self, condition: Mapping[str, Any], user_id: str, display_name: Optional[str]
     ) -> bool:
         """
         Returns true if a user's condition/user ID/display name match the event.
@@ -204,7 +215,7 @@ class PushRuleEvaluatorForEvent:
             #     endpoint with an unknown kind, see _rule_tuple_from_request_object.
             return True
 
-    def _event_match(self, condition: dict, user_id: str) -> bool:
+    def _event_match(self, condition: Mapping, user_id: str) -> bool:
         """
         Check an "event_match" push rule condition.
 
@@ -269,7 +280,7 @@ class PushRuleEvaluatorForEvent:
 
         return bool(r.search(body))
 
-    def _relation_match(self, condition: dict, user_id: str) -> bool:
+    def _relation_match(self, condition: Mapping, user_id: str) -> bool:
         """
         Check an "relation_match" push rule condition.
 
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index 6661887d9f..658bf373b7 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -17,6 +17,7 @@ from synapse.events import EventBase
 from synapse.push.presentable_names import calculate_room_name, name_from_member_event
 from synapse.storage.controllers import StorageControllers
 from synapse.storage.databases.main import DataStore
+from synapse.util.async_helpers import concurrently_execute
 
 
 async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -> int:
@@ -25,13 +26,19 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
 
     badge = len(invites)
 
-    for room_id in joins:
-        notifs = await (
-            store.get_unread_event_push_actions_by_room_for_user(
+    room_notifs = []
+
+    async def get_room_unread_count(room_id: str) -> None:
+        room_notifs.append(
+            await store.get_unread_event_push_actions_by_room_for_user(
                 room_id,
                 user_id,
             )
         )
+
+    await concurrently_execute(get_room_unread_count, joins, 10)
+
+    for notifs in room_notifs:
         if notifs.notify_count == 0:
             continue
 
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index d0cc657b44..1e0ef44fc7 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -328,7 +328,7 @@ class PusherPool:
             return None
 
         try:
-            p = self.pusher_factory.create_pusher(pusher_config)
+            pusher = self.pusher_factory.create_pusher(pusher_config)
         except PusherConfigException as e:
             logger.warning(
                 "Pusher incorrectly configured id=%i, user=%s, appid=%s, pushkey=%s: %s",
@@ -346,23 +346,28 @@ class PusherPool:
             )
             return None
 
-        if not p:
+        if not pusher:
             return None
 
-        appid_pushkey = "%s:%s" % (pusher_config.app_id, pusher_config.pushkey)
+        appid_pushkey = "%s:%s" % (pusher.app_id, pusher.pushkey)
 
-        byuser = self.pushers.setdefault(pusher_config.user_name, {})
+        byuser = self.pushers.setdefault(pusher.user_id, {})
         if appid_pushkey in byuser:
-            byuser[appid_pushkey].on_stop()
-        byuser[appid_pushkey] = p
+            previous_pusher = byuser[appid_pushkey]
+            previous_pusher.on_stop()
 
-        synapse_pushers.labels(type(p).__name__, p.app_id).inc()
+            synapse_pushers.labels(
+                type(previous_pusher).__name__, previous_pusher.app_id
+            ).dec()
+        byuser[appid_pushkey] = pusher
+
+        synapse_pushers.labels(type(pusher).__name__, pusher.app_id).inc()
 
         # Check if there *may* be push to process. We do this as this check is a
         # lot cheaper to do than actually fetching the exact rows we need to
         # push.
-        user_id = pusher_config.user_name
-        last_stream_ordering = pusher_config.last_stream_ordering
+        user_id = pusher.user_id
+        last_stream_ordering = pusher.last_stream_ordering
         if last_stream_ordering:
             have_notifs = await self.store.get_if_maybe_push_in_range_for_user(
                 user_id, last_stream_ordering
@@ -372,9 +377,9 @@ class PusherPool:
             # risk missing push.
             have_notifs = True
 
-        p.on_started(have_notifs)
+        pusher.on_started(have_notifs)
 
-        return p
+        return pusher
 
     async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
         appid_pushkey = "%s:%s" % (app_id, pushkey)
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index a4ae4040c3..acb0bd18f7 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -26,12 +26,13 @@ from twisted.web.server import Request
 
 from synapse.api.errors import HttpResponseException, SynapseError
 from synapse.http import RequestTimedOutError
-from synapse.http.server import HttpServer, is_method_cancellable
+from synapse.http.server import HttpServer
 from synapse.http.site import SynapseRequest
 from synapse.logging import opentracing
-from synapse.logging.opentracing import trace
+from synapse.logging.opentracing import trace_with_opname
 from synapse.types import JsonDict
 from synapse.util.caches.response_cache import ResponseCache
+from synapse.util.cancellation import is_function_cancellable
 from synapse.util.stringutils import random_string
 
 if TYPE_CHECKING:
@@ -196,7 +197,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
                 "ascii"
             )
 
-        @trace(opname="outgoing_replication_request")
+        @trace_with_opname("outgoing_replication_request")
         async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:
             with outgoing_gauge.track_inprogress():
                 if instance_name == local_instance_name:
@@ -311,7 +312,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         url_args = list(self.PATH_ARGS)
         method = self.METHOD
 
-        if self.CACHE and is_method_cancellable(self._handle_request):
+        if self.CACHE and is_function_cancellable(self._handle_request):
             raise Exception(
                 f"{self.__class__.__name__} has been marked as cancellable, but CACHE "
                 "is set. The cancellable flag would have no effect."
@@ -359,6 +360,6 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         # The `@cancellable` decorator may be applied to `_handle_request`. But we
         # told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`,
         # so we have to set up the cancellable flag ourselves.
-        request.is_render_cancellable = is_method_cancellable(self._handle_request)
+        request.is_render_cancellable = is_function_cancellable(self._handle_request)
 
         return await self._handle_request(request, **kwargs)
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
deleted file mode 100644
index 7644146dba..0000000000
--- a/synapse/replication/slave/storage/_base.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# Copyright 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-from typing import TYPE_CHECKING, Optional
-
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
-from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
-from synapse.storage.engines import PostgresEngine
-from synapse.storage.util.id_generators import MultiWriterIdGenerator
-
-if TYPE_CHECKING:
-    from synapse.server import HomeServer
-
-logger = logging.getLogger(__name__)
-
-
-class BaseSlavedStore(CacheInvalidationWorkerStore):
-    def __init__(
-        self,
-        database: DatabasePool,
-        db_conn: LoggingDatabaseConnection,
-        hs: "HomeServer",
-    ):
-        super().__init__(database, db_conn, hs)
-        if isinstance(self.database_engine, PostgresEngine):
-            self._cache_id_gen: Optional[
-                MultiWriterIdGenerator
-            ] = MultiWriterIdGenerator(
-                db_conn,
-                database,
-                stream_name="caches",
-                instance_name=hs.get_instance_name(),
-                tables=[
-                    (
-                        "cache_invalidation_stream_by_instance",
-                        "instance_name",
-                        "stream_id",
-                    )
-                ],
-                sequence_name="cache_invalidation_stream_seq",
-                writers=[],
-            )
-        else:
-            self._cache_id_gen = None
-
-        self.hs = hs
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
deleted file mode 100644
index ee74ee7d85..0000000000
--- a/synapse/replication/slave/storage/account_data.py
+++ /dev/null
@@ -1,22 +0,0 @@
-# Copyright 2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.storage.databases.main.account_data import AccountDataWorkerStore
-from synapse.storage.databases.main.tags import TagsWorkerStore
-
-
-class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
-    pass
diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py
deleted file mode 100644
index 29f50c0add..0000000000
--- a/synapse/replication/slave/storage/appservice.py
+++ /dev/null
@@ -1,25 +0,0 @@
-# Copyright 2015, 2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.storage.databases.main.appservice import (
-    ApplicationServiceTransactionWorkerStore,
-    ApplicationServiceWorkerStore,
-)
-
-
-class SlavedApplicationServiceStore(
-    ApplicationServiceTransactionWorkerStore, ApplicationServiceWorkerStore
-):
-    pass
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
deleted file mode 100644
index e940751084..0000000000
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# Copyright 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
-
-
-class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
-    pass
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index a48cc02069..6fcade510a 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -14,7 +14,6 @@
 
 from typing import TYPE_CHECKING, Any, Iterable
 
-from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
 from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
@@ -24,7 +23,7 @@ if TYPE_CHECKING:
     from synapse.server import HomeServer
 
 
-class SlavedDeviceStore(DeviceWorkerStore, BaseSlavedStore):
+class SlavedDeviceStore(DeviceWorkerStore):
     def __init__(
         self,
         database: DatabasePool,
diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py
deleted file mode 100644
index 71fde0c96c..0000000000
--- a/synapse/replication/slave/storage/directory.py
+++ /dev/null
@@ -1,21 +0,0 @@
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.storage.databases.main.directory import DirectoryWorkerStore
-
-from ._base import BaseSlavedStore
-
-
-class DirectoryStore(DirectoryWorkerStore, BaseSlavedStore):
-    pass
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index a72dad7464..fe47778cb1 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -29,8 +29,6 @@ from synapse.storage.databases.main.stream import StreamWorkerStore
 from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
-from ._base import BaseSlavedStore
-
 if TYPE_CHECKING:
     from synapse.server import HomeServer
 
@@ -56,7 +54,6 @@ class SlavedEventStore(
     EventsWorkerStore,
     UserErasureWorkerStore,
     RelationsWorkerStore,
-    BaseSlavedStore,
 ):
     def __init__(
         self,
diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py
index 4d185e2b56..c52679cd60 100644
--- a/synapse/replication/slave/storage/filtering.py
+++ b/synapse/replication/slave/storage/filtering.py
@@ -14,16 +14,15 @@
 
 from typing import TYPE_CHECKING
 
+from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.filtering import FilteringStore
 
-from ._base import BaseSlavedStore
-
 if TYPE_CHECKING:
     from synapse.server import HomeServer
 
 
-class SlavedFilteringStore(BaseSlavedStore):
+class SlavedFilteringStore(SQLBaseStore):
     def __init__(
         self,
         database: DatabasePool,
diff --git a/synapse/replication/slave/storage/profile.py b/synapse/replication/slave/storage/profile.py
deleted file mode 100644
index 99f4a22642..0000000000
--- a/synapse/replication/slave/storage/profile.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.storage.databases.main.profile import ProfileWorkerStore
-
-
-class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore):
-    pass
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 52ee3f7e58..5e65eaf1e0 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -31,6 +31,5 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
             self._push_rules_stream_id_gen.advance(instance_name, token)
             for row in rows:
                 self.get_push_rules_for_user.invalidate((row.user_id,))
-                self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
                 self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
         return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index de642bba71..44ed20e424 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -18,14 +18,13 @@ from synapse.replication.tcp.streams import PushersStream
 from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.pusher import PusherWorkerStore
 
-from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
 
 
-class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
+class SlavedPusherStore(PusherWorkerStore):
     def __init__(
         self,
         database: DatabasePool,
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
deleted file mode 100644
index 3826b87dec..0000000000
--- a/synapse/replication/slave/storage/receipts.py
+++ /dev/null
@@ -1,22 +0,0 @@
-# Copyright 2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
-
-from ._base import BaseSlavedStore
-
-
-class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
-    pass
diff --git a/synapse/replication/slave/storage/registration.py b/synapse/replication/slave/storage/registration.py
deleted file mode 100644
index 5dae35a960..0000000000
--- a/synapse/replication/slave/storage/registration.py
+++ /dev/null
@@ -1,21 +0,0 @@
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.storage.databases.main.registration import RegistrationWorkerStore
-
-from ._base import BaseSlavedStore
-
-
-class SlavedRegistrationStore(RegistrationWorkerStore, BaseSlavedStore):
-    pass
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 2f59245058..e4f2201c92 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -21,7 +21,7 @@ from twisted.internet.interfaces import IAddress, IConnector
 from twisted.internet.protocol import ReconnectingClientFactory
 from twisted.python.failure import Failure
 
-from synapse.api.constants import EventTypes, ReceiptTypes
+from synapse.api.constants import EventTypes, Membership, ReceiptTypes
 from synapse.federation import send_queue
 from synapse.federation.sender import FederationSender
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
@@ -219,6 +219,21 @@ class ReplicationDataHandler:
                     membership=row.data.membership,
                 )
 
+                # If this event is a join, make a note of it so we have an accurate
+                # cross-worker room rate limit.
+                # TODO: Erik said we should exclude rows that came from ex_outliers
+                #  here, but I don't see how we can determine that. I guess we could
+                #  add a flag to row.data?
+                if (
+                    row.data.type == EventTypes.Member
+                    and row.data.membership == Membership.JOIN
+                    and not row.data.outlier
+                ):
+                    # TODO retrieve the previous state, and exclude join -> join transitions
+                    self.notifier.notify_user_joined_room(
+                        row.data.event_id, row.data.room_id
+                    )
+
         await self._presence_handler.process_replication_rows(
             stream_name, instance_name, token, rows
         )
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index e1cbfa50eb..0f166d16aa 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -35,7 +35,6 @@ from twisted.internet.protocol import ReconnectingClientFactory
 
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.replication.tcp.client import DirectTcpReplicationClientFactory
 from synapse.replication.tcp.commands import (
     ClearUserSyncsCommand,
     Command,
@@ -332,46 +331,31 @@ class ReplicationCommandHandler:
 
     def start_replication(self, hs: "HomeServer") -> None:
         """Helper method to start replication."""
-        if hs.config.redis.redis_enabled:
-            from synapse.replication.tcp.redis import (
-                RedisDirectTcpReplicationClientFactory,
-            )
+        from synapse.replication.tcp.redis import RedisDirectTcpReplicationClientFactory
 
-            # First let's ensure that we have a ReplicationStreamer started.
-            hs.get_replication_streamer()
+        # First let's ensure that we have a ReplicationStreamer started.
+        hs.get_replication_streamer()
 
-            # We need two connections to redis, one for the subscription stream and
-            # one to send commands to (as you can't send further redis commands to a
-            # connection after SUBSCRIBE is called).
+        # We need two connections to redis, one for the subscription stream and
+        # one to send commands to (as you can't send further redis commands to a
+        # connection after SUBSCRIBE is called).
 
-            # First create the connection for sending commands.
-            outbound_redis_connection = hs.get_outbound_redis_connection()
+        # First create the connection for sending commands.
+        outbound_redis_connection = hs.get_outbound_redis_connection()
 
-            # Now create the factory/connection for the subscription stream.
-            self._factory = RedisDirectTcpReplicationClientFactory(
-                hs,
-                outbound_redis_connection,
-                channel_names=self._channels_to_subscribe_to,
-            )
-            hs.get_reactor().connectTCP(
-                hs.config.redis.redis_host,
-                hs.config.redis.redis_port,
-                self._factory,
-                timeout=30,
-                bindAddress=None,
-            )
-        else:
-            client_name = hs.get_instance_name()
-            self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
-            host = hs.config.worker.worker_replication_host
-            port = hs.config.worker.worker_replication_port
-            hs.get_reactor().connectTCP(
-                host,
-                port,
-                self._factory,
-                timeout=30,
-                bindAddress=None,
-            )
+        # Now create the factory/connection for the subscription stream.
+        self._factory = RedisDirectTcpReplicationClientFactory(
+            hs,
+            outbound_redis_connection,
+            channel_names=self._channels_to_subscribe_to,
+        )
+        hs.get_reactor().connectTCP(
+            hs.config.redis.redis_host,
+            hs.config.redis.redis_port,
+            self._factory,
+            timeout=30,
+            bindAddress=None,
+        )
 
     def get_streams(self) -> Dict[str, Stream]:
         """Get a map from stream name to all streams."""
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 26f4fa7cfd..14b6705862 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -98,6 +98,7 @@ class EventsStreamEventRow(BaseEventsStreamRow):
     relates_to: Optional[str]
     membership: Optional[str]
     rejected: bool
+    outlier: bool
 
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
diff --git a/synapse/res/templates/account_previously_renewed.html b/synapse/res/templates/account_previously_renewed.html
index b751359bdf..bd4f7cea97 100644
--- a/synapse/res/templates/account_previously_renewed.html
+++ b/synapse/res/templates/account_previously_renewed.html
@@ -1 +1,12 @@
-<html><body>Your account is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.</body><html>
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>Your account is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.</title>
+</head>
+<body>
+    Your account is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.
+</body>
+</html>
\ No newline at end of file
diff --git a/synapse/res/templates/account_renewed.html b/synapse/res/templates/account_renewed.html
index e8c0f52f05..57b319f375 100644
--- a/synapse/res/templates/account_renewed.html
+++ b/synapse/res/templates/account_renewed.html
@@ -1 +1,12 @@
-<html><body>Your account has been successfully renewed and is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.</body><html>
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>Your account has been successfully renewed and is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.</title>
+</head>
+<body>
+    Your account has been successfully renewed and is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.
+</body>
+</html>
\ No newline at end of file
diff --git a/synapse/res/templates/add_threepid.html b/synapse/res/templates/add_threepid.html
index cc4ab07e09..71f2215b7a 100644
--- a/synapse/res/templates/add_threepid.html
+++ b/synapse/res/templates/add_threepid.html
@@ -1,9 +1,14 @@
-<html>
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>Request to add an email address to your Matrix account</title>
+</head>
 <body>
     <p>A request to add an email address to your Matrix account has been received. If this was you, please click the link below to confirm adding this email:</p>
-
     <a href="{{ link }}">{{ link }}</a>
-
     <p>If this was not you, you can safely ignore this email. Thank you.</p>
 </body>
 </html>
diff --git a/synapse/res/templates/add_threepid_failure.html b/synapse/res/templates/add_threepid_failure.html
index 441d11c846..bd627ee9ce 100644
--- a/synapse/res/templates/add_threepid_failure.html
+++ b/synapse/res/templates/add_threepid_failure.html
@@ -1,8 +1,13 @@
-<html>
-<head></head>
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>Request failed</title>
+</head>
 <body>
-<p>The request failed for the following reason: {{ failure_reason }}.</p>
-
-<p>No changes have been made to your account.</p>
+    <p>The request failed for the following reason: {{ failure_reason }}.</p>
+    <p>No changes have been made to your account.</p>
 </body>
 </html>
diff --git a/synapse/res/templates/add_threepid_success.html b/synapse/res/templates/add_threepid_success.html
index fbd6e4018f..49170c138e 100644
--- a/synapse/res/templates/add_threepid_success.html
+++ b/synapse/res/templates/add_threepid_success.html
@@ -1,6 +1,12 @@
-<html>
-<head></head>
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>Your email has now been validated</title>
+</head>
 <body>
-<p>Your email has now been validated, please return to your client. You may now close this window.</p>
+    <p>Your email has now been validated, please return to your client. You may now close this window.</p>
 </body>
-</html>
+</html>
\ No newline at end of file
diff --git a/synapse/res/templates/auth_success.html b/synapse/res/templates/auth_success.html
index baf4633142..2d6ac44a0e 100644
--- a/synapse/res/templates/auth_success.html
+++ b/synapse/res/templates/auth_success.html
@@ -1,8 +1,8 @@
 <html>
 <head>
 <title>Success!</title>
-<meta name='viewport' content='width=device-width, initial-scale=1,
-    user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
 <link rel="stylesheet" href="/_matrix/static/client/register/style.css">
 <script>
 if (window.onAuthDone) {
diff --git a/synapse/res/templates/invalid_token.html b/synapse/res/templates/invalid_token.html
index 6bd2b98364..2c7c384fe3 100644
--- a/synapse/res/templates/invalid_token.html
+++ b/synapse/res/templates/invalid_token.html
@@ -1 +1,12 @@
-<html><body>Invalid renewal token.</body><html>
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>Invalid renewal token.</title>
+</head>
+<body>
+    Invalid renewal token.
+</body>
+</html>
diff --git a/synapse/res/templates/notice_expiry.html b/synapse/res/templates/notice_expiry.html
index d87311f659..865f9f7ada 100644
--- a/synapse/res/templates/notice_expiry.html
+++ b/synapse/res/templates/notice_expiry.html
@@ -1,6 +1,8 @@
 <!doctype html>
 <html lang="en">
     <head>
+        <meta http-equiv="X-UA-Compatible" content="IE=edge">
+        <meta name="viewport" content="width=device-width, initial-scale=1.0">
         <style type="text/css">
             {% include 'mail.css' without context %}
             {% include "mail-%s.css" % app_name ignore missing without context %}
diff --git a/synapse/res/templates/notif_mail.html b/synapse/res/templates/notif_mail.html
index 27d4182790..9dba0c0253 100644
--- a/synapse/res/templates/notif_mail.html
+++ b/synapse/res/templates/notif_mail.html
@@ -1,6 +1,8 @@
 <!doctype html>
 <html lang="en">
     <head>
+        <meta http-equiv="X-UA-Compatible" content="IE=edge">
+        <meta name="viewport" content="width=device-width, initial-scale=1.0">
         <style type="text/css">
             {%- include 'mail.css' without context %}
             {%- include "mail-%s.css" % app_name ignore missing without context %}
diff --git a/synapse/res/templates/password_reset.html b/synapse/res/templates/password_reset.html
index a197bf872c..a8bdce357b 100644
--- a/synapse/res/templates/password_reset.html
+++ b/synapse/res/templates/password_reset.html
@@ -1,4 +1,9 @@
-<html>
+<html lang="en">
+    <head>
+        <title>Password reset</title>
+        <meta http-equiv="X-UA-Compatible" content="IE=edge">
+        <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    </head>
 <body>
     <p>A password reset request has been received for your Matrix account. If this was you, please click the link below to confirm resetting your password:</p>
 
diff --git a/synapse/res/templates/password_reset_confirmation.html b/synapse/res/templates/password_reset_confirmation.html
index def4b5162b..2e3fd2ec1e 100644
--- a/synapse/res/templates/password_reset_confirmation.html
+++ b/synapse/res/templates/password_reset_confirmation.html
@@ -1,5 +1,9 @@
-<html>
-<head></head>
+<html lang="en">
+<head>
+    <title>Password reset confirmation</title>
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+</head>
 <body>
 <!--Use a hidden form to resubmit the information necessary to reset the password-->
 <form method="post">
diff --git a/synapse/res/templates/password_reset_failure.html b/synapse/res/templates/password_reset_failure.html
index 9e3c4446e3..2d59c463f0 100644
--- a/synapse/res/templates/password_reset_failure.html
+++ b/synapse/res/templates/password_reset_failure.html
@@ -1,5 +1,9 @@
-<html>
-<head></head>
+<html lang="en">
+<head>
+    <title>Password reset failure</title>
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+</head>
 <body>
 <p>The request failed for the following reason: {{ failure_reason }}.</p>
 
diff --git a/synapse/res/templates/password_reset_success.html b/synapse/res/templates/password_reset_success.html
index 7324d66d1e..5165bd1fa2 100644
--- a/synapse/res/templates/password_reset_success.html
+++ b/synapse/res/templates/password_reset_success.html
@@ -1,5 +1,8 @@
-<html>
-<head></head>
+<html lang="en">
+<head>
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+</head>
 <body>
 <p>Your email has now been validated, please return to your client to reset your password. You may now close this window.</p>
 </body>
diff --git a/synapse/res/templates/recaptcha.html b/synapse/res/templates/recaptcha.html
index b3db06ef97..615d3239c6 100644
--- a/synapse/res/templates/recaptcha.html
+++ b/synapse/res/templates/recaptcha.html
@@ -1,8 +1,8 @@
 <html>
 <head>
 <title>Authentication</title>
-<meta name='viewport' content='width=device-width, initial-scale=1,
-    user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
 <script src="https://www.recaptcha.net/recaptcha/api.js"
     async defer></script>
 <script src="//code.jquery.com/jquery-1.11.2.min.js"></script>
diff --git a/synapse/res/templates/registration.html b/synapse/res/templates/registration.html
index 16730a527f..20e831ff4a 100644
--- a/synapse/res/templates/registration.html
+++ b/synapse/res/templates/registration.html
@@ -1,4 +1,9 @@
-<html>
+<html lang="en">
+<head>
+    <title>Registration</title>
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+</head>
 <body>
     <p>You have asked us to register this email with a new Matrix account. If this was you, please click the link below to confirm your email address:</p>
 
diff --git a/synapse/res/templates/registration_failure.html b/synapse/res/templates/registration_failure.html
index 2833d79c37..a6ed22bc90 100644
--- a/synapse/res/templates/registration_failure.html
+++ b/synapse/res/templates/registration_failure.html
@@ -1,5 +1,8 @@
-<html>
-<head></head>
+<html lang="en">
+<head>
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+</head>
 <body>
 <p>Validation failed for the following reason: {{ failure_reason }}.</p>
 </body>
diff --git a/synapse/res/templates/registration_success.html b/synapse/res/templates/registration_success.html
index fbd6e4018f..d51d5549d8 100644
--- a/synapse/res/templates/registration_success.html
+++ b/synapse/res/templates/registration_success.html
@@ -1,5 +1,9 @@
-<html>
-<head></head>
+<html lang="en">
+<head>
+    <title>Your email has now been validated</title>
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+</head>
 <body>
 <p>Your email has now been validated, please return to your client. You may now close this window.</p>
 </body>
diff --git a/synapse/res/templates/registration_token.html b/synapse/res/templates/registration_token.html
index 4577ce1702..59a98f564c 100644
--- a/synapse/res/templates/registration_token.html
+++ b/synapse/res/templates/registration_token.html
@@ -1,8 +1,8 @@
-<html>
+<html lang="en">
 <head>
 <title>Authentication</title>
-<meta name='viewport' content='width=device-width, initial-scale=1,
-    user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
 <link rel="stylesheet" href="/_matrix/static/client/register/style.css">
 </head>
 <body>
diff --git a/synapse/res/templates/sso_account_deactivated.html b/synapse/res/templates/sso_account_deactivated.html
index c3e4deed93..075f801cec 100644
--- a/synapse/res/templates/sso_account_deactivated.html
+++ b/synapse/res/templates/sso_account_deactivated.html
@@ -3,8 +3,8 @@
     <head>
         <meta charset="UTF-8">
         <title>SSO account deactivated</title>
-        <meta name="viewport" content="width=device-width, user-scalable=no">
-        <style type="text/css">
+        <meta http-equiv="X-UA-Compatible" content="IE=edge">
+        <meta name="viewport" content="width=device-width, initial-scale=1.0">        <style type="text/css">
             {% include "sso.css" without context %}
         </style>
     </head>
diff --git a/synapse/res/templates/sso_auth_account_details.html b/synapse/res/templates/sso_auth_account_details.html
index 1ba850369a..2d1db386e1 100644
--- a/synapse/res/templates/sso_auth_account_details.html
+++ b/synapse/res/templates/sso_auth_account_details.html
@@ -3,7 +3,8 @@
   <head>
     <title>Create your account</title>
     <meta charset="utf-8">
-    <meta name="viewport" content="width=device-width, user-scalable=no">
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
     <script type="text/javascript">
       let wasKeyboard = false;
       document.addEventListener("mousedown", function() { wasKeyboard = false; });
@@ -138,7 +139,7 @@
         <div class="username_input" id="username_input">
           <label for="field-username">Username (required)</label>
           <div class="prefix">@</div>
-          <input type="text" name="username" id="field-username" value="{{ user_attributes.localpart }}" autofocus>
+          <input type="text" name="username" id="field-username" value="{{ user_attributes.localpart }}" autofocus autocorrect="off" autocapitalize="none">
           <div class="postfix">:{{ server_name }}</div>
         </div>
         <output for="username_input" id="field-username-output"></output>
diff --git a/synapse/res/templates/sso_auth_bad_user.html b/synapse/res/templates/sso_auth_bad_user.html
index da579ffe69..94403fc3ce 100644
--- a/synapse/res/templates/sso_auth_bad_user.html
+++ b/synapse/res/templates/sso_auth_bad_user.html
@@ -3,7 +3,8 @@
     <head>
         <meta charset="UTF-8">
         <title>Authentication failed</title>
-        <meta name="viewport" content="width=device-width, user-scalable=no">
+        <meta http-equiv="X-UA-Compatible" content="IE=edge">
+        <meta name="viewport" content="width=device-width, initial-scale=1.0">
         <style type="text/css">
             {% include "sso.css" without context %}
         </style>
diff --git a/synapse/res/templates/sso_auth_confirm.html b/synapse/res/templates/sso_auth_confirm.html
index f9d0456f0a..aa1c974a6b 100644
--- a/synapse/res/templates/sso_auth_confirm.html
+++ b/synapse/res/templates/sso_auth_confirm.html
@@ -3,7 +3,8 @@
     <head>
         <meta charset="UTF-8">
         <title>Confirm it's you</title>
-        <meta name="viewport" content="width=device-width, user-scalable=no">
+        <meta http-equiv="X-UA-Compatible" content="IE=edge">
+        <meta name="viewport" content="width=device-width, initial-scale=1.0">
         <style type="text/css">
             {% include "sso.css" without context %}
         </style>
diff --git a/synapse/res/templates/sso_auth_success.html b/synapse/res/templates/sso_auth_success.html
index 1ed3967e87..4898af6011 100644
--- a/synapse/res/templates/sso_auth_success.html
+++ b/synapse/res/templates/sso_auth_success.html
@@ -3,7 +3,8 @@
     <head>
         <meta charset="UTF-8">
         <title>Authentication successful</title>
-        <meta name="viewport" content="width=device-width, user-scalable=no">
+        <meta http-equiv="X-UA-Compatible" content="IE=edge">
+        <meta name="viewport" content="width=device-width, initial-scale=1.0">
         <style type="text/css">
             {% include "sso.css" without context %}
         </style>
diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html
index 472309c350..19992ff2ad 100644
--- a/synapse/res/templates/sso_error.html
+++ b/synapse/res/templates/sso_error.html
@@ -3,7 +3,8 @@
     <head>
         <meta charset="UTF-8">
         <title>Authentication failed</title>
-        <meta name="viewport" content="width=device-width, user-scalable=no">
+        <meta http-equiv="X-UA-Compatible" content="IE=edge">
+        <meta name="viewport" content="width=device-width, initial-scale=1.0">
         <style type="text/css">
             {% include "sso.css" without context %}
 
diff --git a/synapse/res/templates/sso_login_idp_picker.html b/synapse/res/templates/sso_login_idp_picker.html
index 53b82db84e..56fabfa3d2 100644
--- a/synapse/res/templates/sso_login_idp_picker.html
+++ b/synapse/res/templates/sso_login_idp_picker.html
@@ -1,6 +1,8 @@
 <!DOCTYPE html>
 <html lang="en">
     <head>
+        <meta http-equiv="X-UA-Compatible" content="IE=edge">
+        <meta name="viewport" content="width=device-width, initial-scale=1.0">
         <meta charset="UTF-8">
         <title>Choose identity provider</title>
         <style type="text/css">
diff --git a/synapse/res/templates/sso_new_user_consent.html b/synapse/res/templates/sso_new_user_consent.html
index 68c8b9f33a..523f64c4fc 100644
--- a/synapse/res/templates/sso_new_user_consent.html
+++ b/synapse/res/templates/sso_new_user_consent.html
@@ -3,7 +3,8 @@
 <head>
     <meta charset="UTF-8">
     <title>Agree to terms and conditions</title>
-    <meta name="viewport" content="width=device-width, user-scalable=no">
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
     <style type="text/css">
       {% include "sso.css" without context %}
 
diff --git a/synapse/res/templates/sso_redirect_confirm.html b/synapse/res/templates/sso_redirect_confirm.html
index 1b01471ac8..1049a9bd92 100644
--- a/synapse/res/templates/sso_redirect_confirm.html
+++ b/synapse/res/templates/sso_redirect_confirm.html
@@ -3,7 +3,8 @@
 <head>
     <meta charset="UTF-8">
     <title>Continue to your account</title>
-    <meta name="viewport" content="width=device-width, user-scalable=no">
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
     <style type="text/css">
       {% include "sso.css" without context %}
 
diff --git a/synapse/res/templates/terms.html b/synapse/res/templates/terms.html
index 369ff446d2..2081d990ab 100644
--- a/synapse/res/templates/terms.html
+++ b/synapse/res/templates/terms.html
@@ -1,8 +1,8 @@
 <html>
 <head>
 <title>Authentication</title>
-<meta name='viewport' content='width=device-width, initial-scale=1,
-    user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
+<meta http-equiv="X-UA-Compatible" content="IE=edge">
+<meta name="viewport" content="width=device-width, initial-scale=1.0">
 <link rel="stylesheet" href="/_matrix/static/client/register/style.css">
 </head>
 <body>
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index fa3266720b..bac754e1b1 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -61,9 +61,11 @@ from synapse.rest.admin.rooms import (
     MakeRoomAdminRestServlet,
     RoomEventContextServlet,
     RoomMembersRestServlet,
+    RoomMessagesRestServlet,
     RoomRestServlet,
     RoomRestV2Servlet,
     RoomStateRestServlet,
+    RoomTimestampToEventRestServlet,
 )
 from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
 from synapse.rest.admin.statistics import UserMediaStatisticsRestServlet
@@ -271,6 +273,8 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     DestinationResetConnectionRestServlet(hs).register(http_server)
     DestinationRestServlet(hs).register(http_server)
     ListDestinationsRestServlet(hs).register(http_server)
+    RoomMessagesRestServlet(hs).register(http_server)
+    RoomTimestampToEventRestServlet(hs).register(http_server)
 
     # Some servlets only get registered for the main process.
     if hs.config.worker.worker_app is None:
diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py
index 399b205aaf..b467a61dfb 100644
--- a/synapse/rest/admin/_base.py
+++ b/synapse/rest/admin/_base.py
@@ -19,7 +19,7 @@ from typing import Iterable, Pattern
 from synapse.api.auth import Auth
 from synapse.api.errors import AuthError
 from synapse.http.site import SynapseRequest
-from synapse.types import UserID
+from synapse.types import Requester
 
 
 def admin_patterns(path_regex: str, version: str = "v1") -> Iterable[Pattern]:
@@ -48,19 +48,19 @@ async def assert_requester_is_admin(auth: Auth, request: SynapseRequest) -> None
         AuthError if the requester is not a server admin
     """
     requester = await auth.get_user_by_req(request)
-    await assert_user_is_admin(auth, requester.user)
+    await assert_user_is_admin(auth, requester)
 
 
-async def assert_user_is_admin(auth: Auth, user_id: UserID) -> None:
+async def assert_user_is_admin(auth: Auth, requester: Requester) -> None:
     """Verify that the given user is an admin user
 
     Args:
         auth: Auth singleton
-        user_id: user to check
+        requester: The user making the request, according to the access token.
 
     Raises:
         AuthError if the user is not a server admin
     """
-    is_admin = await auth.is_server_admin(user_id)
+    is_admin = await auth.is_server_admin(requester)
     if not is_admin:
         raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 19d4a008e8..73470f09ae 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -54,7 +54,7 @@ class QuarantineMediaInRoom(RestServlet):
         self, request: SynapseRequest, room_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_user_is_admin(self.auth, requester)
 
         logging.info("Quarantining room: %s", room_id)
 
@@ -81,7 +81,7 @@ class QuarantineMediaByUser(RestServlet):
         self, request: SynapseRequest, user_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_user_is_admin(self.auth, requester)
 
         logging.info("Quarantining media by user: %s", user_id)
 
@@ -110,7 +110,7 @@ class QuarantineMediaByID(RestServlet):
         self, request: SynapseRequest, server_name: str, media_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_user_is_admin(self.auth, requester)
 
         logging.info("Quarantining media by ID: %s/%s", server_name, media_id)
 
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 9d953d58de..747e6fda83 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -35,6 +35,7 @@ from synapse.rest.admin._base import (
 )
 from synapse.storage.databases.main.room import RoomSortOrder
 from synapse.storage.state import StateFilter
+from synapse.streams.config import PaginationConfig
 from synapse.types import JsonDict, RoomID, UserID, create_requester
 from synapse.util import json_decoder
 
@@ -75,7 +76,7 @@ class RoomRestV2Servlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
 
         requester = await self._auth.get_user_by_req(request)
-        await assert_user_is_admin(self._auth, requester.user)
+        await assert_user_is_admin(self._auth, requester)
 
         content = parse_json_object_from_request(request)
 
@@ -303,6 +304,7 @@ class RoomRestServlet(RestServlet):
 
         members = await self.store.get_users_in_room(room_id)
         ret["joined_local_devices"] = await self.store.count_devices_by_users(members)
+        ret["forgotten"] = await self.store.is_locally_forgotten_room(room_id)
 
         return HTTPStatus.OK, ret
 
@@ -326,7 +328,7 @@ class RoomRestServlet(RestServlet):
         pagination_handler: "PaginationHandler",
     ) -> Tuple[int, JsonDict]:
         requester = await auth.get_user_by_req(request)
-        await assert_user_is_admin(auth, requester.user)
+        await assert_user_is_admin(auth, requester)
 
         content = parse_json_object_from_request(request)
 
@@ -460,7 +462,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
         assert request.args is not None
 
         requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_user_is_admin(self.auth, requester)
 
         content = parse_json_object_from_request(request)
 
@@ -550,7 +552,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
         self, request: SynapseRequest, room_identifier: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_user_is_admin(self.auth, requester)
         content = parse_json_object_from_request(request, allow_empty_body=True)
 
         room_id, _ = await self.resolve_room_id(room_identifier)
@@ -741,7 +743,7 @@ class RoomEventContextServlet(RestServlet):
         self, request: SynapseRequest, room_id: str, event_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=False)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_user_is_admin(self.auth, requester)
 
         limit = parse_integer(request, "limit", default=10)
 
@@ -833,7 +835,7 @@ class BlockRoomRestServlet(RestServlet):
         self, request: SynapseRequest, room_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self._auth.get_user_by_req(request)
-        await assert_user_is_admin(self._auth, requester.user)
+        await assert_user_is_admin(self._auth, requester)
 
         content = parse_json_object_from_request(request)
 
@@ -857,3 +859,106 @@ class BlockRoomRestServlet(RestServlet):
             await self._store.unblock_room(room_id)
 
         return HTTPStatus.OK, {"block": block}
+
+
+class RoomMessagesRestServlet(RestServlet):
+    """
+    Get messages list of a room.
+    """
+
+    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/messages$")
+
+    def __init__(self, hs: "HomeServer"):
+        self._hs = hs
+        self._clock = hs.get_clock()
+        self._pagination_handler = hs.get_pagination_handler()
+        self._auth = hs.get_auth()
+        self._store = hs.get_datastores().main
+
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
+        requester = await self._auth.get_user_by_req(request)
+        await assert_user_is_admin(self._auth, requester)
+
+        pagination_config = await PaginationConfig.from_request(
+            self._store, request, default_limit=10
+        )
+        # Twisted will have processed the args by now.
+        assert request.args is not None
+        as_client_event = b"raw" not in request.args
+        filter_str = parse_string(request, "filter", encoding="utf-8")
+        if filter_str:
+            filter_json = urlparse.unquote(filter_str)
+            event_filter: Optional[Filter] = Filter(
+                self._hs, json_decoder.decode(filter_json)
+            )
+            if (
+                event_filter
+                and event_filter.filter_json.get("event_format", "client")
+                == "federation"
+            ):
+                as_client_event = False
+        else:
+            event_filter = None
+
+        msgs = await self._pagination_handler.get_messages(
+            room_id=room_id,
+            requester=requester,
+            pagin_config=pagination_config,
+            as_client_event=as_client_event,
+            event_filter=event_filter,
+            use_admin_priviledge=True,
+        )
+
+        return HTTPStatus.OK, msgs
+
+
+class RoomTimestampToEventRestServlet(RestServlet):
+    """
+    API endpoint to fetch the `event_id` of the closest event to the given
+    timestamp (`ts` query parameter) in the given direction (`dir` query
+    parameter).
+
+    Useful for cases like jump to date so you can start paginating messages from
+    a given date in the archive.
+
+    `ts` is a timestamp in milliseconds where we will find the closest event in
+    the given direction.
+
+    `dir` can be `f` or `b` to indicate forwards and backwards in time from the
+    given timestamp.
+
+    GET /_synapse/admin/v1/rooms/<roomID>/timestamp_to_event?ts=<timestamp>&dir=<direction>
+    {
+        "event_id": ...
+    }
+    """
+
+    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/timestamp_to_event$")
+
+    def __init__(self, hs: "HomeServer"):
+        self._auth = hs.get_auth()
+        self._store = hs.get_datastores().main
+        self._timestamp_lookup_handler = hs.get_timestamp_lookup_handler()
+
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
+        requester = await self._auth.get_user_by_req(request)
+        await assert_user_is_admin(self._auth, requester)
+
+        timestamp = parse_integer(request, "ts", required=True)
+        direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])
+
+        (
+            event_id,
+            origin_server_ts,
+        ) = await self._timestamp_lookup_handler.get_event_for_timestamp(
+            requester, room_id, timestamp, direction
+        )
+
+        return HTTPStatus.OK, {
+            "event_id": event_id,
+            "origin_server_ts": origin_server_ts,
+        }
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index f0614a2897..78ee9b6532 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -183,7 +183,7 @@ class UserRestServletV2(RestServlet):
         self, request: SynapseRequest, user_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_user_is_admin(self.auth, requester)
 
         target_user = UserID.from_string(user_id)
         body = parse_json_object_from_request(request)
@@ -373,6 +373,7 @@ class UserRestServletV2(RestServlet):
                     if (
                         self.hs.config.email.email_enable_notifs
                         and self.hs.config.email.email_notif_for_new_users
+                        and medium == "email"
                     ):
                         await self.pusher_pool.add_pusher(
                             user_id=user_id,
@@ -574,10 +575,9 @@ class WhoisRestServlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
         target_user = UserID.from_string(user_id)
         requester = await self.auth.get_user_by_req(request)
-        auth_user = requester.user
 
-        if target_user != auth_user:
-            await assert_user_is_admin(self.auth, auth_user)
+        if target_user != requester.user:
+            await assert_user_is_admin(self.auth, requester)
 
         if not self.is_mine(target_user):
             raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only whois a local user")
@@ -600,7 +600,7 @@ class DeactivateAccountRestServlet(RestServlet):
         self, request: SynapseRequest, target_user_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_user_is_admin(self.auth, requester)
 
         if not self.is_mine(UserID.from_string(target_user_id)):
             raise SynapseError(
@@ -692,7 +692,7 @@ class ResetPasswordRestServlet(RestServlet):
         This needs user to have administrator access in Synapse.
         """
         requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_user_is_admin(self.auth, requester)
 
         UserID.from_string(target_user_id)
 
@@ -806,7 +806,7 @@ class UserAdminServlet(RestServlet):
         self, request: SynapseRequest, user_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_user_is_admin(self.auth, requester)
         auth_user = requester.user
 
         target_user = UserID.from_string(user_id)
@@ -920,7 +920,7 @@ class UserTokenRestServlet(RestServlet):
         self, request: SynapseRequest, user_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_user_is_admin(self.auth, requester)
         auth_user = requester.user
 
         if not self.is_mine_id(user_id):
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index bdc4a9c068..a09aaf3448 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -15,10 +15,11 @@
 # limitations under the License.
 import logging
 import random
-from http import HTTPStatus
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
 from urllib.parse import urlparse
 
+from pydantic import StrictBool, StrictStr, constr
+
 from twisted.web.server import Request
 
 from synapse.api.constants import LoginType
@@ -28,18 +29,24 @@ from synapse.api.errors import (
     SynapseError,
     ThreepidValidationError,
 )
-from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http.server import HttpServer, finish_request, respond_with_html
 from synapse.http.servlet import (
     RestServlet,
     assert_params_in_dict,
+    parse_and_validate_json_object_from_request,
     parse_json_object_from_request,
     parse_string,
 )
 from synapse.http.site import SynapseRequest
 from synapse.metrics import threepid_send_requests
 from synapse.push.mailer import Mailer
+from synapse.rest.client.models import (
+    AuthenticationData,
+    EmailRequestTokenBody,
+    MsisdnRequestTokenBody,
+)
+from synapse.rest.models import RequestBodyModel
 from synapse.types import JsonDict
 from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.stringutils import assert_valid_client_secret, random_string
@@ -64,7 +71,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
         self.config = hs.config
         self.identity_handler = hs.get_identity_handler()
 
-        if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+        if self.config.email.can_verify_email:
             self.mailer = Mailer(
                 hs=self.hs,
                 app_name=self.config.email.email_app_name,
@@ -73,41 +80,24 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
             )
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
-            if self.config.email.local_threepid_handling_disabled_due_to_email_config:
-                logger.warning(
-                    "User password resets have been disabled due to lack of email config"
-                )
+        if not self.config.email.can_verify_email:
+            logger.warning(
+                "User password resets have been disabled due to lack of email config"
+            )
             raise SynapseError(
                 400, "Email-based password resets have been disabled on this server"
             )
 
-        body = parse_json_object_from_request(request)
-
-        assert_params_in_dict(body, ["client_secret", "email", "send_attempt"])
-
-        # Extract params from body
-        client_secret = body["client_secret"]
-        assert_valid_client_secret(client_secret)
-
-        # Canonicalise the email address. The addresses are all stored canonicalised
-        # in the database. This allows the user to reset his password without having to
-        # know the exact spelling (eg. upper and lower case) of address in the database.
-        # Stored in the database "foo@bar.com"
-        # User requests with "FOO@bar.com" would raise a Not Found error
-        try:
-            email = validate_email(body["email"])
-        except ValueError as e:
-            raise SynapseError(400, str(e))
-        send_attempt = body["send_attempt"]
-        next_link = body.get("next_link")  # Optional param
+        body = parse_and_validate_json_object_from_request(
+            request, EmailRequestTokenBody
+        )
 
-        if next_link:
+        if body.next_link:
             # Raise if the provided next_link value isn't valid
-            assert_valid_next_link(self.hs, next_link)
+            assert_valid_next_link(self.hs, body.next_link)
 
         await self.identity_handler.ratelimit_request_token_requests(
-            request, "email", email
+            request, "email", body.email
         )
 
         # The email will be sent to the stored address.
@@ -115,7 +105,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
         # an email address which is controlled by the attacker but which, after
         # canonicalisation, matches the one in our database.
         existing_user_id = await self.hs.get_datastores().main.get_user_id_by_threepid(
-            "email", email
+            "email", body.email
         )
 
         if existing_user_id is None:
@@ -129,35 +119,20 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
 
             raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
 
-        if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
-            assert self.hs.config.registration.account_threepid_delegate_email
-
-            # Have the configured identity server handle the request
-            ret = await self.identity_handler.requestEmailToken(
-                self.hs.config.registration.account_threepid_delegate_email,
-                email,
-                client_secret,
-                send_attempt,
-                next_link,
-            )
-        else:
-            # Send password reset emails from Synapse
-            sid = await self.identity_handler.send_threepid_validation(
-                email,
-                client_secret,
-                send_attempt,
-                self.mailer.send_password_reset_mail,
-                next_link,
-            )
-
-            # Wrap the session id in a JSON object
-            ret = {"sid": sid}
-
+        # Send password reset emails from Synapse
+        sid = await self.identity_handler.send_threepid_validation(
+            body.email,
+            body.client_secret,
+            body.send_attempt,
+            self.mailer.send_password_reset_mail,
+            body.next_link,
+        )
         threepid_send_requests.labels(type="email", reason="password_reset").observe(
-            send_attempt
+            body.send_attempt
         )
 
-        return 200, ret
+        # Wrap the session id in a JSON object
+        return 200, {"sid": sid}
 
 
 class PasswordRestServlet(RestServlet):
@@ -172,16 +147,23 @@ class PasswordRestServlet(RestServlet):
         self.password_policy_handler = hs.get_password_policy_handler()
         self._set_password_handler = hs.get_set_password_handler()
 
+    class PostBody(RequestBodyModel):
+        auth: Optional[AuthenticationData] = None
+        logout_devices: StrictBool = True
+        if TYPE_CHECKING:
+            # workaround for https://github.com/samuelcolvin/pydantic/issues/156
+            new_password: Optional[StrictStr] = None
+        else:
+            new_password: Optional[constr(max_length=512, strict=True)] = None
+
     @interactive_auth_handler
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        body = parse_json_object_from_request(request)
+        body = parse_and_validate_json_object_from_request(request, self.PostBody)
 
         # we do basic sanity checks here because the auth layer will store these
         # in sessions. Pull out the new password provided to us.
-        new_password = body.pop("new_password", None)
+        new_password = body.new_password
         if new_password is not None:
-            if not isinstance(new_password, str) or len(new_password) > 512:
-                raise SynapseError(400, "Invalid password")
             self.password_policy_handler.validate_password(new_password)
 
         # there are two possibilities here. Either the user does not have an
@@ -201,7 +183,7 @@ class PasswordRestServlet(RestServlet):
                 params, session_id = await self.auth_handler.validate_user_via_ui_auth(
                     requester,
                     request,
-                    body,
+                    body.dict(exclude_unset=True),
                     "modify your account password",
                 )
             except InteractiveAuthIncompleteError as e:
@@ -224,7 +206,7 @@ class PasswordRestServlet(RestServlet):
                 result, params, session_id = await self.auth_handler.check_ui_auth(
                     [[LoginType.EMAIL_IDENTITY]],
                     request,
-                    body,
+                    body.dict(exclude_unset=True),
                     "modify your account password",
                 )
             except InteractiveAuthIncompleteError as e:
@@ -299,37 +281,33 @@ class DeactivateAccountRestServlet(RestServlet):
         self.auth_handler = hs.get_auth_handler()
         self._deactivate_account_handler = hs.get_deactivate_account_handler()
 
+    class PostBody(RequestBodyModel):
+        auth: Optional[AuthenticationData] = None
+        id_server: Optional[StrictStr] = None
+        # Not specced, see https://github.com/matrix-org/matrix-spec/issues/297
+        erase: StrictBool = False
+
     @interactive_auth_handler
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        body = parse_json_object_from_request(request)
-        erase = body.get("erase", False)
-        if not isinstance(erase, bool):
-            raise SynapseError(
-                HTTPStatus.BAD_REQUEST,
-                "Param 'erase' must be a boolean, if given",
-                Codes.BAD_JSON,
-            )
+        body = parse_and_validate_json_object_from_request(request, self.PostBody)
 
         requester = await self.auth.get_user_by_req(request)
 
         # allow ASes to deactivate their own users
         if requester.app_service:
             await self._deactivate_account_handler.deactivate_account(
-                requester.user.to_string(), erase, requester
+                requester.user.to_string(), body.erase, requester
             )
             return 200, {}
 
         await self.auth_handler.validate_user_via_ui_auth(
             requester,
             request,
-            body,
+            body.dict(exclude_unset=True),
             "deactivate your account",
         )
         result = await self._deactivate_account_handler.deactivate_account(
-            requester.user.to_string(),
-            erase,
-            requester,
-            id_server=body.get("id_server"),
+            requester.user.to_string(), body.erase, requester, id_server=body.id_server
         )
         if result:
             id_server_unbind_result = "success"
@@ -349,7 +327,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
         self.identity_handler = hs.get_identity_handler()
         self.store = self.hs.get_datastores().main
 
-        if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+        if self.config.email.can_verify_email:
             self.mailer = Mailer(
                 hs=self.hs,
                 app_name=self.config.email.email_app_name,
@@ -358,34 +336,20 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
             )
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
-            if self.config.email.local_threepid_handling_disabled_due_to_email_config:
-                logger.warning(
-                    "Adding emails have been disabled due to lack of an email config"
-                )
+        if not self.config.email.can_verify_email:
+            logger.warning(
+                "Adding emails have been disabled due to lack of an email config"
+            )
             raise SynapseError(
-                400, "Adding an email to your account is disabled on this server"
+                400,
+                "Adding an email to your account is disabled on this server",
             )
 
-        body = parse_json_object_from_request(request)
-        assert_params_in_dict(body, ["client_secret", "email", "send_attempt"])
-        client_secret = body["client_secret"]
-        assert_valid_client_secret(client_secret)
-
-        # Canonicalise the email address. The addresses are all stored canonicalised
-        # in the database.
-        # This ensures that the validation email is sent to the canonicalised address
-        # as it will later be entered into the database.
-        # Otherwise the email will be sent to "FOO@bar.com" and stored as
-        # "foo@bar.com" in database.
-        try:
-            email = validate_email(body["email"])
-        except ValueError as e:
-            raise SynapseError(400, str(e))
-        send_attempt = body["send_attempt"]
-        next_link = body.get("next_link")  # Optional param
+        body = parse_and_validate_json_object_from_request(
+            request, EmailRequestTokenBody
+        )
 
-        if not await check_3pid_allowed(self.hs, "email", email):
+        if not await check_3pid_allowed(self.hs, "email", body.email):
             raise SynapseError(
                 403,
                 "Your email domain is not authorized on this server",
@@ -393,14 +357,14 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
             )
 
         await self.identity_handler.ratelimit_request_token_requests(
-            request, "email", email
+            request, "email", body.email
         )
 
-        if next_link:
+        if body.next_link:
             # Raise if the provided next_link value isn't valid
-            assert_valid_next_link(self.hs, next_link)
+            assert_valid_next_link(self.hs, body.next_link)
 
-        existing_user_id = await self.store.get_user_id_by_threepid("email", email)
+        existing_user_id = await self.store.get_user_id_by_threepid("email", body.email)
 
         if existing_user_id is not None:
             if self.config.server.request_token_inhibit_3pid_errors:
@@ -413,35 +377,21 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
 
             raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
 
-        if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
-            assert self.hs.config.registration.account_threepid_delegate_email
-
-            # Have the configured identity server handle the request
-            ret = await self.identity_handler.requestEmailToken(
-                self.hs.config.registration.account_threepid_delegate_email,
-                email,
-                client_secret,
-                send_attempt,
-                next_link,
-            )
-        else:
-            # Send threepid validation emails from Synapse
-            sid = await self.identity_handler.send_threepid_validation(
-                email,
-                client_secret,
-                send_attempt,
-                self.mailer.send_add_threepid_mail,
-                next_link,
-            )
-
-            # Wrap the session id in a JSON object
-            ret = {"sid": sid}
+        # Send threepid validation emails from Synapse
+        sid = await self.identity_handler.send_threepid_validation(
+            body.email,
+            body.client_secret,
+            body.send_attempt,
+            self.mailer.send_add_threepid_mail,
+            body.next_link,
+        )
 
         threepid_send_requests.labels(type="email", reason="add_threepid").observe(
-            send_attempt
+            body.send_attempt
         )
 
-        return 200, ret
+        # Wrap the session id in a JSON object
+        return 200, {"sid": sid}
 
 
 class MsisdnThreepidRequestTokenRestServlet(RestServlet):
@@ -454,23 +404,16 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
         self.identity_handler = hs.get_identity_handler()
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        body = parse_json_object_from_request(request)
-        assert_params_in_dict(
-            body, ["client_secret", "country", "phone_number", "send_attempt"]
+        body = parse_and_validate_json_object_from_request(
+            request, MsisdnRequestTokenBody
         )
-        client_secret = body["client_secret"]
-        assert_valid_client_secret(client_secret)
-
-        country = body["country"]
-        phone_number = body["phone_number"]
-        send_attempt = body["send_attempt"]
-        next_link = body.get("next_link")  # Optional param
-
-        msisdn = phone_number_to_msisdn(country, phone_number)
+        msisdn = phone_number_to_msisdn(body.country, body.phone_number)
 
         if not await check_3pid_allowed(self.hs, "msisdn", msisdn):
             raise SynapseError(
                 403,
+                # TODO: is this error message accurate? Looks like we've only rejected
+                #       this phone number, not necessarily all phone numbers
                 "Account phone numbers are not authorized on this server",
                 Codes.THREEPID_DENIED,
             )
@@ -479,9 +422,9 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
             request, "msisdn", msisdn
         )
 
-        if next_link:
+        if body.next_link:
             # Raise if the provided next_link value isn't valid
-            assert_valid_next_link(self.hs, next_link)
+            assert_valid_next_link(self.hs, body.next_link)
 
         existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn)
 
@@ -508,15 +451,15 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
 
         ret = await self.identity_handler.requestMsisdnToken(
             self.hs.config.registration.account_threepid_delegate_msisdn,
-            country,
-            phone_number,
-            client_secret,
-            send_attempt,
-            next_link,
+            body.country,
+            body.phone_number,
+            body.client_secret,
+            body.send_attempt,
+            body.next_link,
         )
 
         threepid_send_requests.labels(type="msisdn", reason="add_threepid").observe(
-            send_attempt
+            body.send_attempt
         )
 
         return 200, ret
@@ -534,25 +477,18 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
         self.config = hs.config
         self.clock = hs.get_clock()
         self.store = hs.get_datastores().main
-        if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+        if self.config.email.can_verify_email:
             self._failure_email_template = (
                 self.config.email.email_add_threepid_template_failure_html
             )
 
     async def on_GET(self, request: Request) -> None:
-        if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
-            if self.config.email.local_threepid_handling_disabled_due_to_email_config:
-                logger.warning(
-                    "Adding emails have been disabled due to lack of an email config"
-                )
-            raise SynapseError(
-                400, "Adding an email to your account is disabled on this server"
+        if not self.config.email.can_verify_email:
+            logger.warning(
+                "Adding emails have been disabled due to lack of an email config"
             )
-        elif self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
             raise SynapseError(
-                400,
-                "This homeserver is not validating threepids. Use an identity server "
-                "instead.",
+                400, "Adding an email to your account is disabled on this server"
             )
 
         sid = parse_string(request, "sid", required=True)
@@ -743,10 +679,12 @@ class ThreepidBindRestServlet(RestServlet):
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         body = parse_json_object_from_request(request)
 
-        assert_params_in_dict(body, ["id_server", "sid", "client_secret"])
+        assert_params_in_dict(
+            body, ["id_server", "sid", "id_access_token", "client_secret"]
+        )
         id_server = body["id_server"]
         sid = body["sid"]
-        id_access_token = body.get("id_access_token")  # optional
+        id_access_token = body["id_access_token"]
         client_secret = body["client_secret"]
         assert_valid_client_secret(client_secret)
 
@@ -904,17 +842,18 @@ class AccountStatusRestServlet(RestServlet):
         self._auth = hs.get_auth()
         self._account_handler = hs.get_account_handler()
 
+    class PostBody(RequestBodyModel):
+        # TODO: we could validate that each user id is an mxid here, and/or parse it
+        #       as a UserID
+        user_ids: List[StrictStr]
+
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         await self._auth.get_user_by_req(request)
 
-        body = parse_json_object_from_request(request)
-        if "user_ids" not in body:
-            raise SynapseError(
-                400, "Required parameter 'user_ids' is missing", Codes.MISSING_PARAM
-            )
+        body = parse_and_validate_json_object_from_request(request, self.PostBody)
 
         statuses, failures = await self._account_handler.get_account_statuses(
-            body["user_ids"],
+            body.user_ids,
             allow_remote=True,
         )
 
diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index 6fab102437..ed6ce78d47 100644
--- a/synapse/rest/client/devices.py
+++ b/synapse/rest/client/devices.py
@@ -42,12 +42,26 @@ class DevicesRestServlet(RestServlet):
         self.hs = hs
         self.auth = hs.get_auth()
         self.device_handler = hs.get_device_handler()
+        self._msc3852_enabled = hs.config.experimental.msc3852_enabled
 
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         devices = await self.device_handler.get_devices_by_user(
             requester.user.to_string()
         )
+
+        # If MSC3852 is disabled, then the "last_seen_user_agent" field will be
+        # removed from each device. If it is enabled, then the field name will
+        # be replaced by the unstable identifier.
+        #
+        # When MSC3852 is accepted, this block of code can just be removed to
+        # expose "last_seen_user_agent" to clients.
+        for device in devices:
+            last_seen_user_agent = device["last_seen_user_agent"]
+            del device["last_seen_user_agent"]
+            if self._msc3852_enabled:
+                device["org.matrix.msc3852.last_seen_user_agent"] = last_seen_user_agent
+
         return 200, {"devices": devices}
 
 
@@ -108,6 +122,7 @@ class DeviceRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self.device_handler = hs.get_device_handler()
         self.auth_handler = hs.get_auth_handler()
+        self._msc3852_enabled = hs.config.experimental.msc3852_enabled
 
     async def on_GET(
         self, request: SynapseRequest, device_id: str
@@ -118,6 +133,18 @@ class DeviceRestServlet(RestServlet):
         )
         if device is None:
             raise NotFoundError("No device found")
+
+        # If MSC3852 is disabled, then the "last_seen_user_agent" field will be
+        # removed from each device. If it is enabled, then the field name will
+        # be replaced by the unstable identifier.
+        #
+        # When MSC3852 is accepted, this block of code can just be removed to
+        # expose "last_seen_user_agent" to clients.
+        last_seen_user_agent = device["last_seen_user_agent"]
+        del device["last_seen_user_agent"]
+        if self._msc3852_enabled:
+            device["org.matrix.msc3852.last_seen_user_agent"] = last_seen_user_agent
+
         return 200, device
 
     @interactive_auth_handler
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index ce806e3c11..f653d2a3e1 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -26,10 +26,10 @@ from synapse.http.servlet import (
     parse_string,
 )
 from synapse.http.site import SynapseRequest
-from synapse.logging.opentracing import log_kv, set_tag, trace
+from synapse.logging.opentracing import log_kv, set_tag
+from synapse.rest.client._base import client_patterns, interactive_auth_handler
 from synapse.types import JsonDict, StreamToken
-
-from ._base import client_patterns, interactive_auth_handler
+from synapse.util.cancellation import cancellable
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -71,7 +71,6 @@ class KeyUploadServlet(RestServlet):
         self.e2e_keys_handler = hs.get_e2e_keys_handler()
         self.device_handler = hs.get_device_handler()
 
-    @trace(opname="upload_keys")
     async def on_POST(
         self, request: SynapseRequest, device_id: Optional[str]
     ) -> Tuple[int, JsonDict]:
@@ -157,6 +156,7 @@ class KeyQueryServlet(RestServlet):
         self.auth = hs.get_auth()
         self.e2e_keys_handler = hs.get_e2e_keys_handler()
 
+    @cancellable
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         user_id = requester.user.to_string()
@@ -200,6 +200,7 @@ class KeyChangesServlet(RestServlet):
         self.device_handler = hs.get_device_handler()
         self.store = hs.get_datastores().main
 
+    @cancellable
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
@@ -208,7 +209,9 @@ class KeyChangesServlet(RestServlet):
 
         # We want to enforce they do pass us one, but we ignore it and return
         # changes after the "to" as well as before.
-        set_tag("to", parse_string(request, "to"))
+        #
+        # XXX This does not enforce that "to" is passed.
+        set_tag("to", str(parse_string(request, "to")))
 
         from_token = await StreamToken.from_string(self.store, from_token_string)
 
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index dd75e40f34..0437c87d8d 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -28,7 +28,7 @@ from typing import (
 
 from typing_extensions import TypedDict
 
-from synapse.api.errors import Codes, LoginError, SynapseError
+from synapse.api.errors import Codes, InvalidClientTokenError, LoginError, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.api.urls import CLIENT_API_PREFIX
 from synapse.appservice import ApplicationService
@@ -172,7 +172,13 @@ class LoginRestServlet(RestServlet):
 
         try:
             if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
-                appservice = self.auth.get_appservice_by_req(request)
+                requester = await self.auth.get_user_by_req(request)
+                appservice = requester.app_service
+
+                if appservice is None:
+                    raise InvalidClientTokenError(
+                        "This login method is only valid for application services"
+                    )
 
                 if appservice.is_rate_limited():
                     await self._address_ratelimiter.ratelimit(
diff --git a/synapse/rest/client/models.py b/synapse/rest/client/models.py
new file mode 100644
index 0000000000..6278450c70
--- /dev/null
+++ b/synapse/rest/client/models.py
@@ -0,0 +1,85 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING, Dict, Optional
+
+from pydantic import Extra, StrictInt, StrictStr, constr, validator
+
+from synapse.rest.models import RequestBodyModel
+from synapse.util.threepids import validate_email
+
+
+class AuthenticationData(RequestBodyModel):
+    """
+    Data used during user-interactive authentication.
+
+    (The name "Authentication Data" is taken directly from the spec.)
+
+    Additional keys will be present, depending on the `type` field. Use
+    `.dict(exclude_unset=True)` to access them.
+    """
+
+    class Config:
+        extra = Extra.allow
+
+    session: Optional[StrictStr] = None
+    type: Optional[StrictStr] = None
+
+
+class ThreePidRequestTokenBody(RequestBodyModel):
+    if TYPE_CHECKING:
+        client_secret: StrictStr
+    else:
+        # See also assert_valid_client_secret()
+        client_secret: constr(
+            regex="[0-9a-zA-Z.=_-]",  # noqa: F722
+            min_length=0,
+            max_length=255,
+            strict=True,
+        )
+
+    id_server: Optional[StrictStr]
+    id_access_token: Optional[StrictStr]
+    next_link: Optional[StrictStr]
+    send_attempt: StrictInt
+
+    @validator("id_access_token", always=True)
+    def token_required_for_identity_server(
+        cls, token: Optional[str], values: Dict[str, object]
+    ) -> Optional[str]:
+        if values.get("id_server") is not None and token is None:
+            raise ValueError("id_access_token is required if an id_server is supplied.")
+        return token
+
+
+class EmailRequestTokenBody(ThreePidRequestTokenBody):
+    email: StrictStr
+
+    # Canonicalise the email address. The addresses are all stored canonicalised
+    # in the database. This allows the user to reset his password without having to
+    # know the exact spelling (eg. upper and lower case) of address in the database.
+    # Without this, an email stored in the database as "foo@bar.com" would cause
+    # user requests for "FOO@bar.com" to raise a Not Found error.
+    _email_validator = validator("email", allow_reuse=True)(validate_email)
+
+
+if TYPE_CHECKING:
+    ISO3116_1_Alpha_2 = StrictStr
+else:
+    # Per spec: two-letter uppercase ISO-3166-1-alpha-2
+    ISO3116_1_Alpha_2 = constr(regex="[A-Z]{2}", strict=True)
+
+
+class MsisdnRequestTokenBody(ThreePidRequestTokenBody):
+    country: ISO3116_1_Alpha_2
+    phone_number: StrictStr
diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py
index 24bc7c9095..61268e3af1 100644
--- a/synapse/rest/client/notifications.py
+++ b/synapse/rest/client/notifications.py
@@ -58,7 +58,11 @@ class NotificationsServlet(RestServlet):
         )
 
         receipts_by_room = await self.store.get_receipts_for_user_with_orderings(
-            user_id, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
+            user_id,
+            [
+                ReceiptTypes.READ,
+                ReceiptTypes.READ_PRIVATE,
+            ],
         )
 
         notif_event_ids = [pa.event_id for pa in push_actions]
diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py
index c16d707909..e69fa0829d 100644
--- a/synapse/rest/client/profile.py
+++ b/synapse/rest/client/profile.py
@@ -66,7 +66,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         user = UserID.from_string(user_id)
-        is_admin = await self.auth.is_server_admin(requester.user)
+        is_admin = await self.auth.is_server_admin(requester)
 
         content = parse_json_object_from_request(request)
 
@@ -123,7 +123,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         user = UserID.from_string(user_id)
-        is_admin = await self.auth.is_server_admin(requester.user)
+        is_admin = await self.auth.is_server_admin(requester)
 
         content = parse_json_object_from_request(request)
         try:
diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py
index 3644705e6a..5e53096539 100644
--- a/synapse/rest/client/read_marker.py
+++ b/synapse/rest/client/read_marker.py
@@ -40,6 +40,12 @@ class ReadMarkerRestServlet(RestServlet):
         self.read_marker_handler = hs.get_read_marker_handler()
         self.presence_handler = hs.get_presence_handler()
 
+        self._known_receipt_types = {
+            ReceiptTypes.READ,
+            ReceiptTypes.FULLY_READ,
+            ReceiptTypes.READ_PRIVATE,
+        }
+
     async def on_POST(
         self, request: SynapseRequest, room_id: str
     ) -> Tuple[int, JsonDict]:
@@ -49,13 +55,7 @@ class ReadMarkerRestServlet(RestServlet):
 
         body = parse_json_object_from_request(request)
 
-        valid_receipt_types = {
-            ReceiptTypes.READ,
-            ReceiptTypes.FULLY_READ,
-            ReceiptTypes.READ_PRIVATE,
-        }
-
-        unrecognized_types = set(body.keys()) - valid_receipt_types
+        unrecognized_types = set(body.keys()) - self._known_receipt_types
         if unrecognized_types:
             # It's fine if there are unrecognized receipt types, but let's log
             # it to help debug clients that have typoed the receipt type.
@@ -65,31 +65,25 @@ class ReadMarkerRestServlet(RestServlet):
             # types.
             logger.info("Ignoring unrecognized receipt types: %s", unrecognized_types)
 
-        read_event_id = body.get(ReceiptTypes.READ, None)
-        if read_event_id:
-            await self.receipts_handler.received_client_receipt(
-                room_id,
-                ReceiptTypes.READ,
-                user_id=requester.user.to_string(),
-                event_id=read_event_id,
-            )
-
-        read_private_event_id = body.get(ReceiptTypes.READ_PRIVATE, None)
-        if read_private_event_id and self.config.experimental.msc2285_enabled:
-            await self.receipts_handler.received_client_receipt(
-                room_id,
-                ReceiptTypes.READ_PRIVATE,
-                user_id=requester.user.to_string(),
-                event_id=read_private_event_id,
-            )
-
-        read_marker_event_id = body.get(ReceiptTypes.FULLY_READ, None)
-        if read_marker_event_id:
-            await self.read_marker_handler.received_client_read_marker(
-                room_id,
-                user_id=requester.user.to_string(),
-                event_id=read_marker_event_id,
-            )
+        for receipt_type in self._known_receipt_types:
+            event_id = body.get(receipt_type, None)
+            # TODO Add validation to reject non-string event IDs.
+            if not event_id:
+                continue
+
+            if receipt_type == ReceiptTypes.FULLY_READ:
+                await self.read_marker_handler.received_client_read_marker(
+                    room_id,
+                    user_id=requester.user.to_string(),
+                    event_id=event_id,
+                )
+            else:
+                await self.receipts_handler.received_client_receipt(
+                    room_id,
+                    receipt_type,
+                    user_id=requester.user.to_string(),
+                    event_id=event_id,
+                )
 
         return 200, {}
 
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index 4b03eb876b..5b7fad7402 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -39,31 +39,27 @@ class ReceiptRestServlet(RestServlet):
 
     def __init__(self, hs: "HomeServer"):
         super().__init__()
-        self.hs = hs
         self.auth = hs.get_auth()
         self.receipts_handler = hs.get_receipts_handler()
         self.read_marker_handler = hs.get_read_marker_handler()
         self.presence_handler = hs.get_presence_handler()
 
+        self._known_receipt_types = {
+            ReceiptTypes.READ,
+            ReceiptTypes.READ_PRIVATE,
+            ReceiptTypes.FULLY_READ,
+        }
+
     async def on_POST(
         self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
 
-        if self.hs.config.experimental.msc2285_enabled and receipt_type not in [
-            ReceiptTypes.READ,
-            ReceiptTypes.READ_PRIVATE,
-            ReceiptTypes.FULLY_READ,
-        ]:
+        if receipt_type not in self._known_receipt_types:
             raise SynapseError(
                 400,
-                "Receipt type must be 'm.read', 'org.matrix.msc2285.read.private' or 'm.fully_read'",
+                f"Receipt type must be {', '.join(self._known_receipt_types)}",
             )
-        elif (
-            not self.hs.config.experimental.msc2285_enabled
-            and receipt_type != ReceiptTypes.READ
-        ):
-            raise SynapseError(400, "Receipt type must be 'm.read'")
 
         parse_json_object_from_request(request, allow_empty_body=False)
 
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index e8e51a9c66..20bab20c8f 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -31,9 +31,8 @@ from synapse.api.errors import (
 )
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.config import ConfigError
-from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.config.homeserver import HomeServerConfig
-from synapse.config.ratelimiting import FederationRateLimitConfig
+from synapse.config.ratelimiting import FederationRatelimitSettings
 from synapse.config.server import is_threepid_reserved
 from synapse.handlers.auth import AuthHandler
 from synapse.handlers.ui_auth import UIAuthSessionDataConstants
@@ -74,7 +73,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
         self.identity_handler = hs.get_identity_handler()
         self.config = hs.config
 
-        if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+        if self.hs.config.email.can_verify_email:
             self.mailer = Mailer(
                 hs=self.hs,
                 app_name=self.config.email.email_app_name,
@@ -83,13 +82,10 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
             )
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
-            if (
-                self.hs.config.email.local_threepid_handling_disabled_due_to_email_config
-            ):
-                logger.warning(
-                    "Email registration has been disabled due to lack of email config"
-                )
+        if not self.hs.config.email.can_verify_email:
+            logger.warning(
+                "Email registration has been disabled due to lack of email config"
+            )
             raise SynapseError(
                 400, "Email-based registration has been disabled on this server"
             )
@@ -138,35 +134,21 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
 
             raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
 
-        if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
-            assert self.hs.config.registration.account_threepid_delegate_email
-
-            # Have the configured identity server handle the request
-            ret = await self.identity_handler.requestEmailToken(
-                self.hs.config.registration.account_threepid_delegate_email,
-                email,
-                client_secret,
-                send_attempt,
-                next_link,
-            )
-        else:
-            # Send registration emails from Synapse
-            sid = await self.identity_handler.send_threepid_validation(
-                email,
-                client_secret,
-                send_attempt,
-                self.mailer.send_registration_mail,
-                next_link,
-            )
-
-            # Wrap the session id in a JSON object
-            ret = {"sid": sid}
+        # Send registration emails from Synapse
+        sid = await self.identity_handler.send_threepid_validation(
+            email,
+            client_secret,
+            send_attempt,
+            self.mailer.send_registration_mail,
+            next_link,
+        )
 
         threepid_send_requests.labels(type="email", reason="register").observe(
             send_attempt
         )
 
-        return 200, ret
+        # Wrap the session id in a JSON object
+        return 200, {"sid": sid}
 
 
 class MsisdnRegisterRequestTokenRestServlet(RestServlet):
@@ -260,7 +242,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
         self.clock = hs.get_clock()
         self.store = hs.get_datastores().main
 
-        if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+        if self.config.email.can_verify_email:
             self._failure_email_template = (
                 self.config.email.email_registration_template_failure_html
             )
@@ -270,11 +252,10 @@ class RegistrationSubmitTokenServlet(RestServlet):
             raise SynapseError(
                 400, "This medium is currently not supported for registration"
             )
-        if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
-            if self.config.email.local_threepid_handling_disabled_due_to_email_config:
-                logger.warning(
-                    "User registration via email has been disabled due to lack of email config"
-                )
+        if not self.config.email.can_verify_email:
+            logger.warning(
+                "User registration via email has been disabled due to lack of email config"
+            )
             raise SynapseError(
                 400, "Email-based registration is disabled on this server"
             )
@@ -325,7 +306,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
         self.registration_handler = hs.get_registration_handler()
         self.ratelimiter = FederationRateLimiter(
             hs.get_clock(),
-            FederationRateLimitConfig(
+            FederationRatelimitSettings(
                 # Time window of 2s
                 window_size=2000,
                 # Artificially delay requests if rate > sleep_limit/window_size
@@ -484,9 +465,6 @@ class RegisterRestServlet(RestServlet):
                     "Appservice token must be provided when using a type of m.login.application_service",
                 )
 
-            # Verify the AS
-            self.auth.get_appservice_by_req(request)
-
             # Set the desired user according to the AS API (which uses the
             # 'user' key not 'username'). Since this is a new addition, we'll
             # fallback to 'username' if they gave one.
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 2f513164cb..0bca012535 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -16,9 +16,13 @@
 """ This module contains REST servlets to do with rooms: /rooms/<paths> """
 import logging
 import re
+from enum import Enum
+from http import HTTPStatus
 from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple
 from urllib import parse as urlparse
 
+from prometheus_client.core import Histogram
+
 from twisted.web.server import Request
 
 from synapse import event_auth
@@ -34,7 +38,7 @@ from synapse.api.errors import (
 )
 from synapse.api.filtering import Filter
 from synapse.events.utils import format_event_for_client_v2
-from synapse.http.server import HttpServer, cancellable
+from synapse.http.server import HttpServer
 from synapse.http.servlet import (
     ResolveRoomIdMixin,
     RestServlet,
@@ -46,6 +50,7 @@ from synapse.http.servlet import (
     parse_strings_from_args,
 )
 from synapse.http.site import SynapseRequest
+from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.opentracing import set_tag
 from synapse.rest.client._base import client_patterns
 from synapse.rest.client.transactions import HttpTransactionCache
@@ -53,6 +58,7 @@ from synapse.storage.state import StateFilter
 from synapse.streams.config import PaginationConfig
 from synapse.types import JsonDict, StreamToken, ThirdPartyInstanceID, UserID
 from synapse.util import json_decoder
+from synapse.util.cancellation import cancellable
 from synapse.util.stringutils import parse_and_validate_server_name, random_string
 
 if TYPE_CHECKING:
@@ -61,6 +67,70 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+class _RoomSize(Enum):
+    """
+    Enum to differentiate sizes of rooms. This is a pretty good approximation
+    about how hard it will be to get events in the room. We could also look at
+    room "complexity".
+    """
+
+    # This doesn't necessarily mean the room is a DM, just that there is a DM
+    # amount of people there.
+    DM_SIZE = "direct_message_size"
+    SMALL = "small"
+    SUBSTANTIAL = "substantial"
+    LARGE = "large"
+
+    @staticmethod
+    def from_member_count(member_count: int) -> "_RoomSize":
+        if member_count <= 2:
+            return _RoomSize.DM_SIZE
+        elif member_count < 100:
+            return _RoomSize.SMALL
+        elif member_count < 1000:
+            return _RoomSize.SUBSTANTIAL
+        else:
+            return _RoomSize.LARGE
+
+
+# This is an extra metric on top of `synapse_http_server_response_time_seconds`
+# which times the same sort of thing but this one allows us to see values
+# greater than 10s. We use a separate dedicated histogram with its own buckets
+# so that we don't increase the cardinality of the general one because it's
+# multiplied across hundreds of servlets.
+messsages_response_timer = Histogram(
+    "synapse_room_message_list_rest_servlet_response_time_seconds",
+    "sec",
+    # We have a label for room size so we can try to see a more realistic
+    # picture of /messages response time for bigger rooms. We don't want the
+    # tiny rooms that can always respond fast skewing our results when we're trying
+    # to optimize the bigger cases.
+    ["room_size"],
+    buckets=(
+        0.005,
+        0.01,
+        0.025,
+        0.05,
+        0.1,
+        0.25,
+        0.5,
+        1.0,
+        2.5,
+        5.0,
+        10.0,
+        20.0,
+        30.0,
+        60.0,
+        80.0,
+        100.0,
+        120.0,
+        150.0,
+        180.0,
+        "+Inf",
+    ),
+)
+
+
 class TransactionRestServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         super().__init__()
@@ -165,7 +235,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
 
         msg_handler = self.message_handler
         data = await msg_handler.get_room_data(
-            user_id=requester.user.to_string(),
+            requester=requester,
             room_id=room_id,
             event_type=event_type,
             state_key=state_key,
@@ -510,7 +580,7 @@ class RoomMemberListRestServlet(RestServlet):
 
         events = await handler.get_state_events(
             room_id=room_id,
-            user_id=requester.user.to_string(),
+            requester=requester,
             at_token=at_token,
             state_filter=StateFilter.from_types([(EventTypes.Member, None)]),
         )
@@ -556,6 +626,7 @@ class RoomMessageListRestServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         super().__init__()
         self._hs = hs
+        self.clock = hs.get_clock()
         self.pagination_handler = hs.get_pagination_handler()
         self.auth = hs.get_auth()
         self.store = hs.get_datastores().main
@@ -563,6 +634,18 @@ class RoomMessageListRestServlet(RestServlet):
     async def on_GET(
         self, request: SynapseRequest, room_id: str
     ) -> Tuple[int, JsonDict]:
+        processing_start_time = self.clock.time_msec()
+        # Fire off and hope that we get a result by the end.
+        #
+        # We're using the mypy type ignore comment because the `@cached`
+        # decorator on `get_number_joined_users_in_room` doesn't play well with
+        # the type system. Maybe in the future, it can use some ParamSpec
+        # wizardry to fix it up.
+        room_member_count_deferred = run_in_background(  # type: ignore[call-arg]
+            self.store.get_number_joined_users_in_room,
+            room_id,  # type: ignore[arg-type]
+        )
+
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         pagination_config = await PaginationConfig.from_request(
             self.store, request, default_limit=10
@@ -593,6 +676,12 @@ class RoomMessageListRestServlet(RestServlet):
             event_filter=event_filter,
         )
 
+        processing_end_time = self.clock.time_msec()
+        room_member_count = await make_deferred_yieldable(room_member_count_deferred)
+        messsages_response_timer.labels(
+            room_size=_RoomSize.from_member_count(room_member_count)
+        ).observe((processing_end_time - processing_start_time) / 1000)
+
         return 200, msgs
 
 
@@ -613,8 +702,7 @@ class RoomStateRestServlet(RestServlet):
         # Get all the current state for this room
         events = await self.message_handler.get_state_events(
             room_id=room_id,
-            user_id=requester.user.to_string(),
-            is_guest=requester.is_guest,
+            requester=requester,
         )
         return 200, events
 
@@ -672,7 +760,7 @@ class RoomEventServlet(RestServlet):
             == "true"
         )
         if include_unredacted_content and not await self.auth.is_server_admin(
-            requester.user
+            requester
         ):
             power_level_event = (
                 await self._storage_controllers.state.get_current_state_event(
@@ -860,7 +948,16 @@ class RoomMembershipRestServlet(TransactionRestServlet):
             # cheekily send invalid bodies.
             content = {}
 
-        if membership_action == "invite" and self._has_3pid_invite_keys(content):
+        if membership_action == "invite" and all(
+            key in content for key in ("medium", "address")
+        ):
+            if not all(key in content for key in ("id_server", "id_access_token")):
+                raise SynapseError(
+                    HTTPStatus.BAD_REQUEST,
+                    "`id_server` and `id_access_token` are required when doing 3pid invite",
+                    Codes.MISSING_PARAM,
+                )
+
             try:
                 await self.room_member_handler.do_3pid_invite(
                     room_id,
@@ -870,7 +967,7 @@ class RoomMembershipRestServlet(TransactionRestServlet):
                     content["id_server"],
                     requester,
                     txn_id,
-                    content.get("id_access_token"),
+                    content["id_access_token"],
                 )
             except ShadowBanError:
                 # Pretend the request succeeded.
@@ -907,12 +1004,6 @@ class RoomMembershipRestServlet(TransactionRestServlet):
 
         return 200, return_value
 
-    def _has_3pid_invite_keys(self, content: JsonDict) -> bool:
-        for key in {"id_server", "medium", "address"}:
-            if key not in content:
-                return False
-        return True
-
     def on_PUT(
         self, request: SynapseRequest, room_id: str, membership_action: str, txn_id: str
     ) -> Awaitable[Tuple[int, JsonDict]]:
@@ -1177,9 +1268,7 @@ class TimestampLookupRestServlet(RestServlet):
         self, request: SynapseRequest, room_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self._auth.get_user_by_req(request)
-        await self._auth.check_user_in_room_or_world_readable(
-            room_id, requester.user.to_string()
-        )
+        await self._auth.check_user_in_room_or_world_readable(room_id, requester)
 
         timestamp = parse_integer(request, "ts", required=True)
         direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])
diff --git a/synapse/rest/client/room_keys.py b/synapse/rest/client/room_keys.py
index 37e39570f6..f7081f638e 100644
--- a/synapse/rest/client/room_keys.py
+++ b/synapse/rest/client/room_keys.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple, cast
 
 from synapse.api.errors import Codes, NotFoundError, SynapseError
 from synapse.http.server import HttpServer
@@ -127,7 +127,7 @@ class RoomKeysServlet(RestServlet):
         requester = await self.auth.get_user_by_req(request, allow_guest=False)
         user_id = requester.user.to_string()
         body = parse_json_object_from_request(request)
-        version = parse_string(request, "version")
+        version = parse_string(request, "version", required=True)
 
         if session_id:
             body = {"sessions": {session_id: body}}
@@ -196,8 +196,11 @@ class RoomKeysServlet(RestServlet):
         user_id = requester.user.to_string()
         version = parse_string(request, "version", required=True)
 
-        room_keys = await self.e2e_room_keys_handler.get_room_keys(
-            user_id, version, room_id, session_id
+        room_keys = cast(
+            JsonDict,
+            await self.e2e_room_keys_handler.get_room_keys(
+                user_id, version, room_id, session_id
+            ),
         )
 
         # Convert room_keys to the right format to return.
@@ -240,7 +243,7 @@ class RoomKeysServlet(RestServlet):
 
         requester = await self.auth.get_user_by_req(request, allow_guest=False)
         user_id = requester.user.to_string()
-        version = parse_string(request, "version")
+        version = parse_string(request, "version", required=True)
 
         ret = await self.e2e_room_keys_handler.delete_room_keys(
             user_id, version, room_id, session_id
diff --git a/synapse/rest/client/sendtodevice.py b/synapse/rest/client/sendtodevice.py
index 3322c8ef48..46a8b03829 100644
--- a/synapse/rest/client/sendtodevice.py
+++ b/synapse/rest/client/sendtodevice.py
@@ -19,7 +19,7 @@ from synapse.http import servlet
 from synapse.http.server import HttpServer
 from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
 from synapse.http.site import SynapseRequest
-from synapse.logging.opentracing import set_tag, trace
+from synapse.logging.opentracing import set_tag
 from synapse.rest.client.transactions import HttpTransactionCache
 from synapse.types import JsonDict
 
@@ -43,7 +43,6 @@ class SendToDeviceRestServlet(servlet.RestServlet):
         self.txns = HttpTransactionCache(hs)
         self.device_message_handler = hs.get_device_message_handler()
 
-    @trace(opname="sendToDevice")
     def on_PUT(
         self, request: SynapseRequest, message_type: str, txn_id: str
     ) -> Awaitable[Tuple[int, JsonDict]]:
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 8bbf35148d..c2989765ce 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -37,7 +37,7 @@ from synapse.handlers.sync import (
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
-from synapse.logging.opentracing import trace
+from synapse.logging.opentracing import trace_with_opname
 from synapse.types import JsonDict, StreamToken
 from synapse.util import json_decoder
 
@@ -210,7 +210,7 @@ class SyncRestServlet(RestServlet):
         logger.debug("Event formatting complete")
         return 200, response_content
 
-    @trace(opname="sync.encode_response")
+    @trace_with_opname("sync.encode_response")
     async def encode_response(
         self,
         time_now: int,
@@ -315,7 +315,7 @@ class SyncRestServlet(RestServlet):
             ]
         }
 
-    @trace(opname="sync.encode_joined")
+    @trace_with_opname("sync.encode_joined")
     async def encode_joined(
         self,
         rooms: List[JoinedSyncResult],
@@ -340,7 +340,7 @@ class SyncRestServlet(RestServlet):
 
         return joined
 
-    @trace(opname="sync.encode_invited")
+    @trace_with_opname("sync.encode_invited")
     async def encode_invited(
         self,
         rooms: List[InvitedSyncResult],
@@ -371,7 +371,7 @@ class SyncRestServlet(RestServlet):
 
         return invited
 
-    @trace(opname="sync.encode_knocked")
+    @trace_with_opname("sync.encode_knocked")
     async def encode_knocked(
         self,
         rooms: List[KnockedSyncResult],
@@ -420,7 +420,7 @@ class SyncRestServlet(RestServlet):
 
         return knocked
 
-    @trace(opname="sync.encode_archived")
+    @trace_with_opname("sync.encode_archived")
     async def encode_archived(
         self,
         rooms: List[ArchivedSyncResult],
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index f4f06563dd..c516cda95d 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -94,9 +94,9 @@ class VersionsRestServlet(RestServlet):
                     # Supports the busy presence state described in MSC3026.
                     "org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled,
                     # Supports receiving private read receipts as per MSC2285
-                    "org.matrix.msc2285": self.config.experimental.msc2285_enabled,
-                    # Supports filtering of /publicRooms by room type MSC3827
-                    "org.matrix.msc3827": self.config.experimental.msc3827_enabled,
+                    "org.matrix.msc2285.stable": True,  # TODO: Remove when MSC2285 becomes a part of the spec
+                    # Supports filtering of /publicRooms by room type as per MSC3827
+                    "org.matrix.msc3827.stable": True,
                     # Adds support for importing historical messages as per MSC2716
                     "org.matrix.msc2716": self.config.experimental.msc2716_enabled,
                     # Adds support for jump to date endpoints (/timestamp_to_event) as per MSC3030
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index f597157581..7f8ad29566 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -135,13 +135,6 @@ class RemoteKey(DirectServeJsonResource):
 
         store_queries = []
         for server_name, key_ids in query.items():
-            if (
-                self.federation_domain_whitelist is not None
-                and server_name not in self.federation_domain_whitelist
-            ):
-                logger.debug("Federation denied with %s", server_name)
-                continue
-
             if not key_ids:
                 key_ids = (None,)
             for key_id in key_ids:
@@ -153,21 +146,28 @@ class RemoteKey(DirectServeJsonResource):
 
         time_now_ms = self.clock.time_msec()
 
-        # Note that the value is unused.
+        # Map server_name->key_id->int. Note that the value of the init is unused.
+        # XXX: why don't we just use a set?
         cache_misses: Dict[str, Dict[str, int]] = {}
         for (server_name, key_id, _), key_results in cached.items():
             results = [(result["ts_added_ms"], result) for result in key_results]
 
-            if not results and key_id is not None:
-                cache_misses.setdefault(server_name, {})[key_id] = 0
+            if key_id is None:
+                # all keys were requested. Just return what we have without worrying
+                # about validity
+                for _, result in results:
+                    # Cast to bytes since postgresql returns a memoryview.
+                    json_results.add(bytes(result["key_json"]))
                 continue
 
-            if key_id is not None:
+            miss = False
+            if not results:
+                miss = True
+            else:
                 ts_added_ms, most_recent_result = max(results)
                 ts_valid_until_ms = most_recent_result["ts_valid_until_ms"]
                 req_key = query.get(server_name, {}).get(key_id, {})
                 req_valid_until = req_key.get("minimum_valid_until_ts")
-                miss = False
                 if req_valid_until is not None:
                     if ts_valid_until_ms < req_valid_until:
                         logger.debug(
@@ -211,19 +211,20 @@ class RemoteKey(DirectServeJsonResource):
                         ts_valid_until_ms,
                         time_now_ms,
                     )
-
-                if miss:
-                    cache_misses.setdefault(server_name, {})[key_id] = 0
                 # Cast to bytes since postgresql returns a memoryview.
                 json_results.add(bytes(most_recent_result["key_json"]))
-            else:
-                for _, result in results:
-                    # Cast to bytes since postgresql returns a memoryview.
-                    json_results.add(bytes(result["key_json"]))
+
+            if miss and query_remote_on_cache_miss:
+                # only bother attempting to fetch keys from servers on our whitelist
+                if (
+                    self.federation_domain_whitelist is None
+                    or server_name in self.federation_domain_whitelist
+                ):
+                    cache_misses.setdefault(server_name, {})[key_id] = 0
 
         # If there is a cache miss, request the missing keys, then recurse (and
         # ensure the result is sent).
-        if cache_misses and query_remote_on_cache_miss:
+        if cache_misses:
             await yieldable_gather_results(
                 lambda t: self.fetcher.get_keys(*t),
                 (
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index c35d42fab8..d30878f704 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -254,30 +254,32 @@ async def respond_with_responder(
         file_size: Size in bytes of the media. If not known it should be None
         upload_name: The name of the requested file, if any.
     """
-    if request._disconnected:
-        logger.warning(
-            "Not sending response to request %s, already disconnected.", request
-        )
-        return
-
     if not responder:
         respond_404(request)
         return
 
-    logger.debug("Responding to media request with responder %s", responder)
-    add_file_headers(request, media_type, file_size, upload_name)
-    try:
-        with responder:
+    # If we have a responder we *must* use it as a context manager.
+    with responder:
+        if request._disconnected:
+            logger.warning(
+                "Not sending response to request %s, already disconnected.", request
+            )
+            return
+
+        logger.debug("Responding to media request with responder %s", responder)
+        add_file_headers(request, media_type, file_size, upload_name)
+        try:
+
             await responder.write_to_consumer(request)
-    except Exception as e:
-        # The majority of the time this will be due to the client having gone
-        # away. Unfortunately, Twisted simply throws a generic exception at us
-        # in that case.
-        logger.warning("Failed to write to consumer: %s %s", type(e), e)
-
-        # Unregister the producer, if it has one, so Twisted doesn't complain
-        if request.producer:
-            request.unregisterProducer()
+        except Exception as e:
+            # The majority of the time this will be due to the client having gone
+            # away. Unfortunately, Twisted simply throws a generic exception at us
+            # in that case.
+            logger.warning("Failed to write to consumer: %s %s", type(e), e)
+
+            # Unregister the producer, if it has one, so Twisted doesn't complain
+            if request.producer:
+                request.unregisterProducer()
 
     finish_request(request)
 
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 7435fd9130..9dd3c8d4bb 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -64,7 +64,6 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
-
 # How often to run the background job to update the "recently accessed"
 # attribute of local and remote media.
 UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000  # 1 minute
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 54a849eac9..a8f6fd6b35 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -109,10 +109,64 @@ class MediaInfo:
 
 class PreviewUrlResource(DirectServeJsonResource):
     """
-    Generating URL previews is a complicated task which many potential pitfalls.
-
-    See docs/development/url_previews.md for discussion of the design and
-    algorithm followed in this module.
+    The `GET /_matrix/media/r0/preview_url` endpoint provides a generic preview API
+    for URLs which outputs Open Graph (https://ogp.me/) responses (with some Matrix
+    specific additions).
+
+    This does have trade-offs compared to other designs:
+
+    * Pros:
+      * Simple and flexible; can be used by any clients at any point
+    * Cons:
+      * If each homeserver provides one of these independently, all the homeservers in a
+        room may needlessly DoS the target URI
+      * The URL metadata must be stored somewhere, rather than just using Matrix
+        itself to store the media.
+      * Matrix cannot be used to distribute the metadata between homeservers.
+
+    When Synapse is asked to preview a URL it does the following:
+
+    1. Checks against a URL blacklist (defined as `url_preview_url_blacklist` in the
+       config).
+    2. Checks the URL against an in-memory cache and returns the result if it exists. (This
+       is also used to de-duplicate processing of multiple in-flight requests at once.)
+    3. Kicks off a background process to generate a preview:
+       1. Checks URL and timestamp against the database cache and returns the result if it
+          has not expired and was successful (a 2xx return code).
+       2. Checks if the URL matches an oEmbed (https://oembed.com/) pattern. If it
+          does, update the URL to download.
+       3. Downloads the URL and stores it into a file via the media storage provider
+          and saves the local media metadata.
+       4. If the media is an image:
+          1. Generates thumbnails.
+          2. Generates an Open Graph response based on image properties.
+       5. If the media is HTML:
+          1. Decodes the HTML via the stored file.
+          2. Generates an Open Graph response from the HTML.
+          3. If a JSON oEmbed URL was found in the HTML via autodiscovery:
+             1. Downloads the URL and stores it into a file via the media storage provider
+                and saves the local media metadata.
+             2. Convert the oEmbed response to an Open Graph response.
+             3. Override any Open Graph data from the HTML with data from oEmbed.
+          4. If an image exists in the Open Graph response:
+             1. Downloads the URL and stores it into a file via the media storage
+                provider and saves the local media metadata.
+             2. Generates thumbnails.
+             3. Updates the Open Graph response based on image properties.
+       6. If the media is JSON and an oEmbed URL was found:
+          1. Convert the oEmbed response to an Open Graph response.
+          2. If a thumbnail or image is in the oEmbed response:
+             1. Downloads the URL and stores it into a file via the media storage
+                provider and saves the local media metadata.
+             2. Generates thumbnails.
+             3. Updates the Open Graph response based on image properties.
+       7. Stores the result in the database cache.
+    4. Returns the result.
+
+    The in-memory cache expires after 1 hour.
+
+    Expired entries in the database cache (and their associated media files) are
+    deleted every 10 seconds. The default expiration time is 1 hour from download.
     """
 
     isLeaf = True
@@ -678,10 +732,6 @@ class PreviewUrlResource(DirectServeJsonResource):
 
         logger.debug("Running url preview cache expiry")
 
-        if not (await self.store.db_pool.updates.has_completed_background_updates()):
-            logger.debug("Still running DB updates; skipping url preview cache expiry")
-            return
-
         def try_remove_parent_dirs(dirs: Iterable[str]) -> None:
             """Attempt to remove the given chain of parent directories
 
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 2295adfaa7..5f725c7600 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -17,9 +17,11 @@
 import logging
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
-from synapse.api.errors import SynapseError
+from synapse.api.errors import Codes, SynapseError, cs_error
+from synapse.config.repository import THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP
 from synapse.http.server import (
     DirectServeJsonResource,
+    respond_with_json,
     set_corp_headers,
     set_cors_headers,
 )
@@ -309,6 +311,19 @@ class ThumbnailResource(DirectServeJsonResource):
             url_cache: True if this is from a URL cache.
             server_name: The server name, if this is a remote thumbnail.
         """
+        logger.debug(
+            "_select_and_respond_with_thumbnail: media_id=%s desired=%sx%s (%s) thumbnail_infos=%s",
+            media_id,
+            desired_width,
+            desired_height,
+            desired_method,
+            thumbnail_infos,
+        )
+
+        # If `dynamic_thumbnails` is enabled, we expect Synapse to go down a
+        # different code path to handle it.
+        assert not self.dynamic_thumbnails
+
         if thumbnail_infos:
             file_info = self._select_thumbnail(
                 desired_width,
@@ -384,8 +399,29 @@ class ThumbnailResource(DirectServeJsonResource):
                 file_info.thumbnail.length,
             )
         else:
+            # This might be because:
+            # 1. We can't create thumbnails for the given media (corrupted or
+            #    unsupported file type), or
+            # 2. The thumbnailing process never ran or errored out initially
+            #    when the media was first uploaded (these bugs should be
+            #    reported and fixed).
+            # Note that we don't attempt to generate a thumbnail now because
+            # `dynamic_thumbnails` is disabled.
             logger.info("Failed to find any generated thumbnails")
-            respond_404(request)
+
+            respond_with_json(
+                request,
+                400,
+                cs_error(
+                    "Cannot find any thumbnails for the requested media (%r). This might mean the media is not a supported_media_format=(%s) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)"
+                    % (
+                        request.postpath,
+                        ", ".join(THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP.keys()),
+                    ),
+                    code=Codes.UNKNOWN,
+                ),
+                send_cors=True,
+            )
 
     def _select_thumbnail(
         self,
diff --git a/synapse/rest/models.py b/synapse/rest/models.py
new file mode 100644
index 0000000000..ac39cda8e5
--- /dev/null
+++ b/synapse/rest/models.py
@@ -0,0 +1,23 @@
+from pydantic import BaseModel, Extra
+
+
+class RequestBodyModel(BaseModel):
+    """A custom version of Pydantic's BaseModel which
+
+     - ignores unknown fields and
+     - does not allow fields to be overwritten after construction,
+
+    but otherwise uses Pydantic's default behaviour.
+
+    Ignoring unknown fields is a useful default. It means that clients can provide
+    unstable field not known to the server without the request being refused outright.
+
+    Subclassing in this way is recommended by
+    https://pydantic-docs.helpmanual.io/usage/model_config/#change-behaviour-globally
+    """
+
+    class Config:
+        # By default, ignore fields that we don't recognise.
+        extra = Extra.ignore
+        # By default, don't allow fields to be reassigned after parsing.
+        allow_mutation = False
diff --git a/synapse/rest/synapse/client/password_reset.py b/synapse/rest/synapse/client/password_reset.py
index 6ac9dbc7c9..b9402cfb75 100644
--- a/synapse/rest/synapse/client/password_reset.py
+++ b/synapse/rest/synapse/client/password_reset.py
@@ -17,7 +17,6 @@ from typing import TYPE_CHECKING, Tuple
 from twisted.web.server import Request
 
 from synapse.api.errors import ThreepidValidationError
-from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.http.server import DirectServeHtmlResource
 from synapse.http.servlet import parse_string
 from synapse.util.stringutils import assert_valid_client_secret
@@ -46,9 +45,6 @@ class PasswordResetSubmitTokenResource(DirectServeHtmlResource):
         self.clock = hs.get_clock()
         self.store = hs.get_datastores().main
 
-        self._local_threepid_handling_disabled_due_to_email_config = (
-            hs.config.email.local_threepid_handling_disabled_due_to_email_config
-        )
         self._confirmation_email_template = (
             hs.config.email.email_password_reset_template_confirmation_html
         )
@@ -59,8 +55,8 @@ class PasswordResetSubmitTokenResource(DirectServeHtmlResource):
             hs.config.email.email_password_reset_template_failure_html
         )
 
-        # This resource should not be mounted if threepid behaviour is not LOCAL
-        assert hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL
+        # This resource should only be mounted if email validation is enabled
+        assert hs.config.email.can_verify_email
 
     async def _async_render_GET(self, request: Request) -> Tuple[int, bytes]:
         sid = parse_string(request, "sid", required=True)
diff --git a/synapse/server.py b/synapse/server.py
index 181984a1a4..df3a1cb405 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -105,6 +105,7 @@ from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler
 from synapse.handlers.user_directory import UserDirectoryHandler
 from synapse.http.client import InsecureInterceptableContextFactory, SimpleHttpClient
 from synapse.http.matrixfederationclient import MatrixFederationHttpClient
+from synapse.metrics.common_usage_metrics import CommonUsageMetricsManager
 from synapse.module_api import ModuleApi
 from synapse.notifier import Notifier
 from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
@@ -340,7 +341,17 @@ class HomeServer(metaclass=abc.ABCMeta):
         return domain_specific_string.domain == self.hostname
 
     def is_mine_id(self, string: str) -> bool:
-        return string.split(":", 1)[1] == self.hostname
+        """Determines whether a user ID or room alias originates from this homeserver.
+
+        Returns:
+            `True` if the hostname part of the user ID or room alias matches this
+            homeserver.
+            `False` otherwise, or if the user ID or room alias is malformed.
+        """
+        localpart_hostname = string.split(":", 1)
+        if len(localpart_hostname) < 2:
+            return False
+        return localpart_hostname[1] == self.hostname
 
     @cache_in_self
     def get_clock(self) -> Clock:
@@ -756,7 +767,9 @@ class HomeServer(metaclass=abc.ABCMeta):
     @cache_in_self
     def get_federation_ratelimiter(self) -> FederationRateLimiter:
         return FederationRateLimiter(
-            self.get_clock(), config=self.config.ratelimiting.rc_federation
+            self.get_clock(),
+            config=self.config.ratelimiting.rc_federation,
+            metrics_name="federation_servlets",
         )
 
     @cache_in_self
@@ -827,3 +840,8 @@ class HomeServer(metaclass=abc.ABCMeta):
             self.config.ratelimiting.rc_message,
             self.config.ratelimiting.rc_admin_redaction,
         )
+
+    @cache_in_self
+    def get_common_usage_metrics_manager(self) -> CommonUsageMetricsManager:
+        """Usage metrics shared between phone home stats and the prometheus exporter."""
+        return CommonUsageMetricsManager(self)
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index 8ecab86ec7..564e3705c2 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -102,6 +102,10 @@ class ServerNoticesManager:
         Returns:
             The room's ID, or None if no room could be found.
         """
+        # If there is no server notices MXID, then there is no server notices room
+        if self.server_notices_mxid is None:
+            return None
+
         rooms = await self._store.get_rooms_for_local_user_where_membership_is(
             user_id, [Membership.INVITE, Membership.JOIN]
         )
@@ -111,8 +115,10 @@ class ServerNoticesManager:
             # be joined. This is kinda deliberate, in that if somebody somehow
             # manages to invite the system user to a room, that doesn't make it
             # the server notices room.
-            user_ids = await self._store.get_users_in_room(room.room_id)
-            if len(user_ids) <= 2 and self.server_notices_mxid in user_ids:
+            is_server_notices_room = await self._store.check_local_user_in_room(
+                user_id=self.server_notices_mxid, room_id=room.room_id
+            )
+            if is_server_notices_room:
                 # we found a room which our user shares with the system notice
                 # user
                 return room.room_id
@@ -244,7 +250,7 @@ class ServerNoticesManager:
         assert self.server_notices_mxid is not None
 
         notice_user_data_in_room = await self._message_handler.get_room_data(
-            self.server_notices_mxid,
+            create_requester(self.server_notices_mxid),
             room_id,
             EventTypes.Member,
             self.server_notices_mxid,
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 781d9f06da..3787d35b24 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import heapq
 import logging
-from collections import defaultdict
+from collections import ChainMap, defaultdict
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -24,14 +24,12 @@ from typing import (
     DefaultDict,
     Dict,
     FrozenSet,
-    Iterable,
     List,
     Mapping,
     Optional,
     Sequence,
     Set,
     Tuple,
-    Union,
 )
 
 import attr
@@ -46,7 +44,7 @@ from synapse.logging.context import ContextResourceUsage
 from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
 from synapse.state import v1, v2
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.roommember import ProfileInfo
+from synapse.storage.state import StateFilter
 from synapse.types import StateMap
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -54,6 +52,7 @@ from synapse.util.metrics import Measure, measure_func
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
+    from synapse.storage.controllers import StateStorageController
     from synapse.storage.databases.main import DataStore
 
 logger = logging.getLogger(__name__)
@@ -83,17 +82,26 @@ def _gen_state_id() -> str:
 
 
 class _StateCacheEntry:
-    __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
+    __slots__ = ["_state", "state_group", "prev_group", "delta_ids"]
 
     def __init__(
         self,
-        state: StateMap[str],
+        state: Optional[StateMap[str]],
         state_group: Optional[int],
         prev_group: Optional[int] = None,
         delta_ids: Optional[StateMap[str]] = None,
     ):
+        if state is None and state_group is None and prev_group is None:
+            raise Exception("One of state, state_group or prev_group must be not None")
+
+        if prev_group is not None and delta_ids is None:
+            raise Exception("If prev_group is set so must delta_ids")
+
         # A map from (type, state_key) to event_id.
-        self.state = frozendict(state)
+        #
+        # This can be None if we have a `state_group` (as then we can fetch the
+        # state from the DB.)
+        self._state = frozendict(state) if state is not None else None
 
         # the ID of a state group if one and only one is involved.
         # otherwise, None otherwise?
@@ -102,20 +110,60 @@ class _StateCacheEntry:
         self.prev_group = prev_group
         self.delta_ids = frozendict(delta_ids) if delta_ids is not None else None
 
-        # The `state_id` is a unique ID we generate that can be used as ID for
-        # this collection of state. Usually this would be the same as the
-        # state group, but on worker instances we can't generate a new state
-        # group each time we resolve state, so we generate a separate one that
-        # isn't persisted and is used solely for caches.
-        # `state_id` is either a state_group (and so an int) or a string. This
-        # ensures we don't accidentally persist a state_id as a stateg_group
-        if state_group:
-            self.state_id: Union[str, int] = state_group
-        else:
-            self.state_id = _gen_state_id()
+    async def get_state(
+        self,
+        state_storage: "StateStorageController",
+        state_filter: Optional["StateFilter"] = None,
+    ) -> StateMap[str]:
+        """Get the state map for this entry, either from the in-memory state or
+        looking up the state group in the DB.
+        """
+
+        if self._state is not None:
+            return self._state
+
+        if self.state_group is not None:
+            return await state_storage.get_state_ids_for_group(
+                self.state_group, state_filter
+            )
+
+        assert self.prev_group is not None and self.delta_ids is not None
+
+        prev_state = await state_storage.get_state_ids_for_group(
+            self.prev_group, state_filter
+        )
+
+        # ChainMap expects MutableMapping, but since we're using it immutably
+        # its safe to give it immutable maps.
+        return ChainMap(self.delta_ids, prev_state)  # type: ignore[arg-type]
+
+    def set_state_group(self, state_group: int) -> None:
+        """Update the state group assigned to this state (e.g. after we've
+        persisted it).
+
+        Note: this will cause the cache entry to drop any stored state.
+        """
+
+        self.state_group = state_group
+
+        # We clear out the state as we know longer need to explicitly keep it in
+        # the `state_cache` (as the store state group cache will do that).
+        self._state = None
 
     def __len__(self) -> int:
-        return len(self.state)
+        # The len should be used to estimate how large this cache entry is, for
+        # cache eviction purposes. This is why it's fine to return 1 if we're
+        # not storing any state.
+
+        length = 0
+
+        if self._state:
+            length += len(self._state)
+
+        if self.delta_ids:
+            length += len(self.delta_ids)
+
+        return length or 1  # Make sure its not 0.
 
 
 class StateHandler:
@@ -137,29 +185,35 @@ class StateHandler:
             ReplicationUpdateCurrentStateRestServlet.make_client(hs)
         )
 
-    async def get_current_state_ids(
+    async def compute_state_after_events(
         self,
         room_id: str,
-        latest_event_ids: Collection[str],
+        event_ids: Collection[str],
+        state_filter: Optional[StateFilter] = None,
     ) -> StateMap[str]:
-        """Get the current state, or the state at a set of events, for a room
+        """Fetch the state after each of the given event IDs. Resolve them and return.
+
+        This is typically used where `event_ids` is a collection of forward extremities
+        in a room, intended to become the `prev_events` of a new event E. If so, the
+        return value of this function represents the state before E.
 
         Args:
-            room_id:
-            latest_event_ids: The forward extremities to resolve.
+            room_id: the room_id containing the given events.
+            event_ids: the events whose state should be fetched and resolved.
 
         Returns:
-            the state dict, mapping from (event_type, state_key) -> event_id
+            the state dict (a mapping from (event_type, state_key) -> event_id) which
+            holds the resolution of the states after the given event IDs.
         """
-        logger.debug("calling resolve_state_groups from get_current_state_ids")
-        ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
-        return ret.state
+        logger.debug("calling resolve_state_groups from compute_state_after_events")
+        ret = await self.resolve_state_groups_for_events(room_id, event_ids)
+        return await ret.get_state(self._state_storage_controller, state_filter)
 
-    async def get_current_users_in_room(
+    async def get_current_user_ids_in_room(
         self, room_id: str, latest_event_ids: List[str]
-    ) -> Dict[str, ProfileInfo]:
+    ) -> Set[str]:
         """
-        Get the users who are currently in a room.
+        Get the users IDs who are currently in a room.
 
         Note: This is much slower than using the equivalent method
         `DataStore.get_users_in_room` or `DataStore.get_users_in_room_with_profiles`,
@@ -170,14 +224,15 @@ class StateHandler:
             room_id: The ID of the room.
             latest_event_ids: Precomputed list of latest event IDs. Will be computed if None.
         Returns:
-            Dictionary of user IDs to their profileinfo.
+            Set of user IDs in the room.
         """
 
         assert latest_event_ids is not None
 
-        logger.debug("calling resolve_state_groups from get_current_users_in_room")
+        logger.debug("calling resolve_state_groups from get_current_user_ids_in_room")
         entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
-        return await self.store.get_joined_users_from_state(room_id, entry)
+        state = await entry.get_state(self._state_storage_controller, StateFilter.all())
+        return await self.store.get_joined_user_ids_from_state(room_id, state)
 
     async def get_hosts_in_room_at_events(
         self, room_id: str, event_ids: Collection[str]
@@ -192,13 +247,14 @@ class StateHandler:
             The hosts in the room at the given events
         """
         entry = await self.resolve_state_groups_for_events(room_id, event_ids)
-        return await self.store.get_joined_hosts(room_id, entry)
+        state = await entry.get_state(self._state_storage_controller, StateFilter.all())
+        return await self.store.get_joined_hosts(room_id, state, entry)
 
     async def compute_event_context(
         self,
         event: EventBase,
         state_ids_before_event: Optional[StateMap[str]] = None,
-        partial_state: bool = False,
+        partial_state: Optional[bool] = None,
     ) -> EventContext:
         """Build an EventContext structure for a non-outlier event.
 
@@ -213,10 +269,18 @@ class StateHandler:
                 it can't be calculated from existing events. This is normally
                 only specified when receiving an event from federation where we
                 don't have the prev events, e.g. when backfilling.
-            partial_state: True if `state_ids_before_event` is partial and omits
-                non-critical membership events
+            partial_state:
+                `True` if `state_ids_before_event` is partial and omits non-critical
+                membership events.
+                `False` if `state_ids_before_event` is the full state.
+                `None` when `state_ids_before_event` is not provided. In this case, the
+                flag will be calculated based on `event`'s prev events.
         Returns:
             The event context.
+
+        Raises:
+            RuntimeError if `state_ids_before_event` is not provided and one or more
+                prev events are missing or outliers.
         """
 
         assert not event.internal_metadata.is_outlier()
@@ -227,17 +291,28 @@ class StateHandler:
         #
         if state_ids_before_event:
             # if we're given the state before the event, then we use that
-            state_group_before_event = None
             state_group_before_event_prev_group = None
             deltas_to_state_group_before_event = None
-            entry = None
 
+            # .. though we need to get a state group for it.
+            state_group_before_event = (
+                await self._state_storage_controller.store_state_group(
+                    event.event_id,
+                    event.room_id,
+                    prev_group=None,
+                    delta_ids=None,
+                    current_state_ids=state_ids_before_event,
+                )
+            )
+
+            # the partial_state flag must be provided
+            assert partial_state is not None
         else:
             # otherwise, we'll need to resolve the state across the prev_events.
 
             # partial_state should not be set explicitly in this case:
             # we work it out dynamically
-            assert not partial_state
+            assert partial_state is None
 
             # if any of the prev-events have partial state, so do we.
             # (This is slightly racy - the prev-events might get fixed up before we use
@@ -247,13 +322,13 @@ class StateHandler:
             incomplete_prev_events = await self.store.get_partial_state_events(
                 prev_event_ids
             )
-            if any(incomplete_prev_events.values()):
+            partial_state = any(incomplete_prev_events.values())
+            if partial_state:
                 logger.debug(
                     "New/incoming event %s refers to prev_events %s with partial state",
                     event.event_id,
                     [k for (k, v) in incomplete_prev_events.items() if v],
                 )
-                partial_state = True
 
             logger.debug("calling resolve_state_groups from compute_event_context")
             # we've already taken into account partial state, so no need to wait for
@@ -264,36 +339,32 @@ class StateHandler:
                 await_full_state=False,
             )
 
-            state_ids_before_event = entry.state
-            state_group_before_event = entry.state_group
             state_group_before_event_prev_group = entry.prev_group
             deltas_to_state_group_before_event = entry.delta_ids
+            state_ids_before_event = None
+
+            # We make sure that we have a state group assigned to the state.
+            if entry.state_group is None:
+                # store_state_group requires us to have either a previous state group
+                # (with deltas) or the complete state map. So, if we don't have a
+                # previous state group, load the complete state map now.
+                if state_group_before_event_prev_group is None:
+                    state_ids_before_event = await entry.get_state(
+                        self._state_storage_controller, StateFilter.all()
+                    )
 
-        #
-        # make sure that we have a state group at that point. If it's not a state event,
-        # that will be the state group for the new event. If it *is* a state event,
-        # it might get rejected (in which case we'll need to persist it with the
-        # previous state group)
-        #
-
-        if not state_group_before_event:
-            state_group_before_event = (
-                await self._state_storage_controller.store_state_group(
-                    event.event_id,
-                    event.room_id,
-                    prev_group=state_group_before_event_prev_group,
-                    delta_ids=deltas_to_state_group_before_event,
-                    current_state_ids=state_ids_before_event,
+                state_group_before_event = (
+                    await self._state_storage_controller.store_state_group(
+                        event.event_id,
+                        event.room_id,
+                        prev_group=state_group_before_event_prev_group,
+                        delta_ids=deltas_to_state_group_before_event,
+                        current_state_ids=state_ids_before_event,
+                    )
                 )
-            )
-
-            # Assign the new state group to the cached state entry.
-            #
-            # Note that this can race in that we could generate multiple state
-            # groups for the same state entry, but that is just inefficient
-            # rather than dangerous.
-            if entry and entry.state_group is None:
-                entry.state_group = state_group_before_event
+                entry.set_state_group(state_group_before_event)
+            else:
+                state_group_before_event = entry.state_group
 
         #
         # now if it's not a state event, we're done
@@ -315,13 +386,18 @@ class StateHandler:
         #
 
         key = (event.type, event.state_key)
-        if key in state_ids_before_event:
-            replaces = state_ids_before_event[key]
-            if replaces != event.event_id:
-                event.unsigned["replaces_state"] = replaces
 
-        state_ids_after_event = dict(state_ids_before_event)
-        state_ids_after_event[key] = event.event_id
+        if state_ids_before_event is not None:
+            replaces = state_ids_before_event.get(key)
+        else:
+            replaces_state_map = await entry.get_state(
+                self._state_storage_controller, StateFilter.from_types([key])
+            )
+            replaces = replaces_state_map.get(key)
+
+        if replaces and replaces != event.event_id:
+            event.unsigned["replaces_state"] = replaces
+
         delta_ids = {key: event.event_id}
 
         state_group_after_event = (
@@ -330,7 +406,7 @@ class StateHandler:
                 event.room_id,
                 prev_group=state_group_before_event,
                 delta_ids=delta_ids,
-                current_state_ids=state_ids_after_event,
+                current_state_ids=None,
             )
         )
 
@@ -359,6 +435,10 @@ class StateHandler:
 
         Returns:
             The resolved state
+
+        Raises:
+            RuntimeError if we don't have a state group for one or more of the events
+               (ie. they are outliers or unknown)
         """
         logger.debug("resolve_state_groups event_ids %s", event_ids)
 
@@ -372,9 +452,6 @@ class StateHandler:
         state_group_ids_set = set(state_group_ids)
         if len(state_group_ids_set) == 1:
             (state_group_id,) = state_group_ids_set
-            state = await self._state_storage_controller.get_state_for_groups(
-                state_group_ids_set
-            )
             (
                 prev_group,
                 delta_ids,
@@ -382,7 +459,7 @@ class StateHandler:
                 state_group_id
             )
             return _StateCacheEntry(
-                state=state[state_group_id],
+                state=None,
                 state_group=state_group_id,
                 prev_group=prev_group,
                 delta_ids=delta_ids,
@@ -405,31 +482,6 @@ class StateHandler:
         )
         return result
 
-    async def resolve_events(
-        self,
-        room_version: str,
-        state_sets: Collection[Iterable[EventBase]],
-        event: EventBase,
-    ) -> StateMap[EventBase]:
-        logger.info(
-            "Resolving state for %s with %d groups", event.room_id, len(state_sets)
-        )
-        state_set_ids = [
-            {(ev.type, ev.state_key): ev.event_id for ev in st} for st in state_sets
-        ]
-
-        state_map = {ev.event_id: ev for st in state_sets for ev in st}
-
-        new_state = await self._state_resolution_handler.resolve_events_with_store(
-            event.room_id,
-            room_version,
-            state_set_ids,
-            event_map=state_map,
-            state_res_store=StateResolutionStore(self.store),
-        )
-
-        return {key: state_map[ev_id] for key, ev_id in new_state.items()}
-
     async def update_current_state(self, room_id: str) -> None:
         """Recalculates the current state for a room, and persists it.
 
@@ -741,7 +793,7 @@ def _make_state_cache_entry(
         old_state_event_ids = set(state.values())
         if new_state_event_ids == old_state_event_ids:
             # got an exact match.
-            return _StateCacheEntry(state=new_state, state_group=sg)
+            return _StateCacheEntry(state=None, state_group=sg)
 
     # TODO: We want to create a state group for this set of events, to
     # increase cache hits, but we need to make sure that it doesn't
@@ -752,14 +804,25 @@ def _make_state_cache_entry(
     delta_ids: Optional[StateMap[str]] = None
 
     for old_group, old_state in state_groups_ids.items():
+        if old_state.keys() - new_state.keys():
+            # Currently we don't support deltas that remove keys from the state
+            # map, so we have to ignore this group as a candidate to base the
+            # new group on.
+            continue
+
         n_delta_ids = {k: v for k, v in new_state.items() if old_state.get(k) != v}
         if not delta_ids or len(n_delta_ids) < len(delta_ids):
             prev_group = old_group
             delta_ids = n_delta_ids
 
-    return _StateCacheEntry(
-        state=new_state, state_group=None, prev_group=prev_group, delta_ids=delta_ids
-    )
+    if prev_group is not None:
+        # If we have a prev group and deltas then we can drop the new state from
+        # the cache (to reduce memory usage).
+        return _StateCacheEntry(
+            state=None, state_group=None, prev_group=prev_group, delta_ids=delta_ids
+        )
+    else:
+        return _StateCacheEntry(state=new_state, state_group=None)
 
 
 @attr.s(slots=True, auto_attribs=True)
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 7db032203b..af03851c71 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -271,40 +271,41 @@ async def _get_power_level_for_sender(
 async def _get_auth_chain_difference(
     room_id: str,
     state_sets: Sequence[Mapping[Any, str]],
-    event_map: Dict[str, EventBase],
+    unpersisted_events: Dict[str, EventBase],
     state_res_store: StateResolutionStore,
 ) -> Set[str]:
     """Compare the auth chains of each state set and return the set of events
-    that only appear in some but not all of the auth chains.
+    that only appear in some, but not all of the auth chains.
 
     Args:
-        state_sets
-        event_map
-        state_res_store
+        state_sets: The input state sets we are trying to resolve across.
+        unpersisted_events: A map from event ID to EventBase containing all unpersisted
+            events involved in this resolution.
+        state_res_store:
 
     Returns:
-        Set of event IDs
+        The auth difference of the given state sets, as a set of event IDs.
     """
 
     # The `StateResolutionStore.get_auth_chain_difference` function assumes that
     # all events passed to it (and their auth chains) have been persisted
-    # previously. This is not the case for any events in the `event_map`, and so
-    # we need to manually handle those events.
+    # previously. We need to manually handle any other events that are yet to be
+    # persisted.
     #
-    # We do this by:
-    #   1. calculating the auth chain difference for the state sets based on the
-    #      events in `event_map` alone
-    #   2. replacing any events in the state_sets that are also in `event_map`
-    #      with their auth events (recursively), and then calling
-    #      `store.get_auth_chain_difference` as normal
-    #   3. adding the results of 1 and 2 together.
-
-    # Map from event ID in `event_map` to their auth event IDs, and their auth
-    # event IDs if they appear in the `event_map`. This is the intersection of
-    # the event's auth chain with the events in the `event_map` *plus* their
+    # We do this in three steps:
+    #   1. Compute the set of unpersisted events belonging to the auth difference.
+    #   2. Replacing any unpersisted events in the state_sets with their auth events,
+    #      recursively, until the state_sets contain only persisted events.
+    #      Then we call `store.get_auth_chain_difference` as normal, which computes
+    #      the set of persisted events belonging to the auth difference.
+    #   3. Adding the results of 1 and 2 together.
+
+    # Map from event ID in `unpersisted_events` to their auth event IDs, and their auth
+    # event IDs if they appear in the `unpersisted_events`. This is the intersection of
+    # the event's auth chain with the events in `unpersisted_events` *plus* their
     # auth event IDs.
     events_to_auth_chain: Dict[str, Set[str]] = {}
-    for event in event_map.values():
+    for event in unpersisted_events.values():
         chain = {event.event_id}
         events_to_auth_chain[event.event_id] = chain
 
@@ -312,16 +313,16 @@ async def _get_auth_chain_difference(
         while to_search:
             for auth_id in to_search.pop().auth_event_ids():
                 chain.add(auth_id)
-                auth_event = event_map.get(auth_id)
+                auth_event = unpersisted_events.get(auth_id)
                 if auth_event:
                     to_search.append(auth_event)
 
-    # We now a) calculate the auth chain difference for the unpersisted events
-    # and b) work out the state sets to pass to the store.
+    # We now 1) calculate the auth chain difference for the unpersisted events
+    # and 2) work out the state sets to pass to the store.
     #
-    # Note: If the `event_map` is empty (which is the common case), we can do a
+    # Note: If there are no `unpersisted_events` (which is the common case), we can do a
     # much simpler calculation.
-    if event_map:
+    if unpersisted_events:
         # The list of state sets to pass to the store, where each state set is a set
         # of the event ids making up the state. This is similar to `state_sets`,
         # except that (a) we only have event ids, not the complete
@@ -344,14 +345,18 @@ async def _get_auth_chain_difference(
             for event_id in state_set.values():
                 event_chain = events_to_auth_chain.get(event_id)
                 if event_chain is not None:
-                    # We have an event in `event_map`. We add all the auth
-                    # events that it references (that aren't also in `event_map`).
-                    set_ids.update(e for e in event_chain if e not in event_map)
+                    # We have an unpersisted event. We add all the auth
+                    # events that it references which are also unpersisted.
+                    set_ids.update(
+                        e for e in event_chain if e not in unpersisted_events
+                    )
 
                     # We also add the full chain of unpersisted event IDs
                     # referenced by this state set, so that we can work out the
                     # auth chain difference of the unpersisted events.
-                    unpersisted_ids.update(e for e in event_chain if e in event_map)
+                    unpersisted_ids.update(
+                        e for e in event_chain if e in unpersisted_events
+                    )
                 else:
                     set_ids.add(event_id)
 
@@ -361,15 +366,15 @@ async def _get_auth_chain_difference(
         union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
         intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
 
-        difference_from_event_map: Collection[str] = union - intersection
+        auth_difference_unpersisted_part: Collection[str] = union - intersection
     else:
-        difference_from_event_map = ()
+        auth_difference_unpersisted_part = ()
         state_sets_ids = [set(state_set.values()) for state_set in state_sets]
 
     difference = await state_res_store.get_auth_chain_difference(
         room_id, state_sets_ids
     )
-    difference.update(difference_from_event_map)
+    difference.update(auth_difference_unpersisted_part)
 
     return difference
 
@@ -434,7 +439,7 @@ async def _add_event_and_auth_chain_to_graph(
     event_id: str,
     event_map: Dict[str, EventBase],
     state_res_store: StateResolutionStore,
-    auth_diff: Set[str],
+    full_conflicted_set: Set[str],
 ) -> None:
     """Helper function for _reverse_topological_power_sort that add the event
     and its auth chain (that is in the auth diff) to the graph
@@ -445,7 +450,7 @@ async def _add_event_and_auth_chain_to_graph(
         event_id: Event to add to the graph
         event_map
         state_res_store
-        auth_diff: Set of event IDs that are in the auth difference.
+        full_conflicted_set: Set of event IDs that are in the full conflicted set.
     """
 
     state = [event_id]
@@ -455,7 +460,7 @@ async def _add_event_and_auth_chain_to_graph(
 
         event = await _get_event(room_id, eid, event_map, state_res_store)
         for aid in event.auth_event_ids():
-            if aid in auth_diff:
+            if aid in full_conflicted_set:
                 if aid not in graph:
                     state.append(aid)
 
@@ -468,7 +473,7 @@ async def _reverse_topological_power_sort(
     event_ids: Iterable[str],
     event_map: Dict[str, EventBase],
     state_res_store: StateResolutionStore,
-    auth_diff: Set[str],
+    full_conflicted_set: Set[str],
 ) -> List[str]:
     """Returns a list of the event_ids sorted by reverse topological ordering,
     and then by power level and origin_server_ts
@@ -479,7 +484,7 @@ async def _reverse_topological_power_sort(
         event_ids: The events to sort
         event_map
         state_res_store
-        auth_diff: Set of event IDs that are in the auth difference.
+        full_conflicted_set: Set of event IDs that are in the full conflicted set.
 
     Returns:
         The sorted list
@@ -488,7 +493,7 @@ async def _reverse_topological_power_sort(
     graph: Dict[str, Set[str]] = {}
     for idx, event_id in enumerate(event_ids, start=1):
         await _add_event_and_auth_chain_to_graph(
-            graph, room_id, event_id, event_map, state_res_store, auth_diff
+            graph, room_id, event_id, event_map, state_res_store, full_conflicted_set
         )
 
         # We await occasionally when we're working with large data sets to
diff --git a/synapse/static/client/login/index.html b/synapse/static/client/login/index.html
index 9e6daf38ac..40510889ac 100644
--- a/synapse/static/client/login/index.html
+++ b/synapse/static/client/login/index.html
@@ -3,7 +3,8 @@
 <head>
     <meta http-equiv="Content-Type" content="text/html; charset=utf-8">
     <title> Login </title>
-    <meta name='viewport' content='width=device-width, initial-scale=1, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
     <link rel="stylesheet" href="style.css">
     <script src="js/jquery-3.4.1.min.js"></script>
     <script src="js/login.js"></script>
diff --git a/synapse/static/client/register/index.html b/synapse/static/client/register/index.html
index 140653574d..27bbd76f51 100644
--- a/synapse/static/client/register/index.html
+++ b/synapse/static/client/register/index.html
@@ -2,7 +2,8 @@
 <html>
 <head>
 <title> Registration </title>
-<meta name='viewport' content='width=device-width, initial-scale=1, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'> 
+<meta http-equiv="X-UA-Compatible" content="IE=edge">
+<meta name="viewport" content="width=device-width, initial-scale=1.0">
 <link rel="stylesheet" href="style.css">
 <script src="js/jquery-3.4.1.min.js"></script>
 <script src="https://www.recaptcha.net/recaptcha/api/js/recaptcha_ajax.js"></script>
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index b8c8dcd76b..e30f9c76d4 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -80,6 +80,10 @@ class SQLBaseStore(metaclass=ABCMeta):
             )
             self._attempt_to_invalidate_cache("get_local_users_in_room", (room_id,))
 
+            # There's no easy way of invalidating this cache for just the users
+            # that have changed, so we just clear the entire thing.
+            self._attempt_to_invalidate_cache("does_pair_of_users_share_a_room", None)
+
         for user_id in members_changed:
             self._attempt_to_invalidate_cache(
                 "get_user_in_room_with_profile", (room_id, user_id)
@@ -96,6 +100,10 @@ class SQLBaseStore(metaclass=ABCMeta):
         cache doesn't exist. Mainly used for invalidating caches on workers,
         where they may not have the cache.
 
+        Note that this function does not invalidate any remote caches, only the
+        local in-memory ones. Any remote invalidation must be performed before
+        calling this.
+
         Args:
             cache_name
             key: Entry to invalidate. If None then invalidates the entire
@@ -112,7 +120,10 @@ class SQLBaseStore(metaclass=ABCMeta):
         if key is None:
             cache.invalidate_all()
         else:
-            cache.invalidate(tuple(key))
+            # Prefer any local-only invalidation method. Invalidating any non-local
+            # cache must be be done before this.
+            invalidate_method = getattr(cache, "invalidate_local", cache.invalidate)
+            invalidate_method(tuple(key))
 
 
 def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 555b4e77d2..cf1eabc437 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -581,9 +581,6 @@ class BackgroundUpdater:
         def create_index_sqlite(conn: Connection) -> None:
             # Sqlite doesn't support concurrent creation of indexes.
             #
-            # We don't use partial indices on SQLite as it wasn't introduced
-            # until 3.8, and wheezy and CentOS 7 have 3.7
-            #
             # We assume that sqlite doesn't give us invalid indices; however
             # we may still end up with the index existing but the
             # background_updates not having been recorded if synapse got shut
@@ -591,12 +588,13 @@ class BackgroundUpdater:
             # has supported CREATE TABLE|INDEX IF NOT EXISTS since 3.3.0.)
             sql = (
                 "CREATE %(unique)s INDEX IF NOT EXISTS %(name)s ON %(table)s"
-                " (%(columns)s)"
+                " (%(columns)s) %(where_clause)s"
             ) % {
                 "unique": "UNIQUE" if unique else "",
                 "name": index_name,
                 "table": table,
                 "columns": ", ".join(columns),
+                "where_clause": "WHERE " + where_clause if where_clause else "",
             }
 
             c = conn.cursor()
diff --git a/synapse/storage/controllers/__init__.py b/synapse/storage/controllers/__init__.py
index 55649719f6..45101cda7a 100644
--- a/synapse/storage/controllers/__init__.py
+++ b/synapse/storage/controllers/__init__.py
@@ -43,4 +43,6 @@ class StorageControllers:
 
         self.persistence = None
         if stores.persist_events:
-            self.persistence = EventsPersistenceStorageController(hs, stores)
+            self.persistence = EventsPersistenceStorageController(
+                hs, stores, self.state
+            )
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index ea499ce0f8..501dbbc990 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -45,12 +45,20 @@ from twisted.internet import defer
 from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
-from synapse.logging import opentracing
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.logging.opentracing import (
+    SynapseTags,
+    active_span,
+    set_tag,
+    start_active_span_follows_from,
+    trace,
+)
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.controllers.state import StateStorageController
 from synapse.storage.databases import Databases
 from synapse.storage.databases.main.events import DeltaState
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
+from synapse.storage.state import StateFilter
 from synapse.types import (
     PersistedEventPosition,
     RoomStreamToken,
@@ -221,7 +229,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
             queue.append(end_item)
 
         # also add our active opentracing span to the item so that we get a link back
-        span = opentracing.active_span()
+        span = active_span()
         if span:
             end_item.parent_opentracing_span_contexts.append(span.context)
 
@@ -232,7 +240,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
         res = await make_deferred_yieldable(end_item.deferred.observe())
 
         # add another opentracing span which links to the persist trace.
-        with opentracing.start_active_span_follows_from(
+        with start_active_span_follows_from(
             f"{task.name}_complete", (end_item.opentracing_span_context,)
         ):
             pass
@@ -264,7 +272,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
                 queue = self._get_drainining_queue(room_id)
                 for item in queue:
                     try:
-                        with opentracing.start_active_span_follows_from(
+                        with start_active_span_follows_from(
                             item.task.name,
                             item.parent_opentracing_span_contexts,
                             inherit_force_tracing=True,
@@ -308,7 +316,12 @@ class EventsPersistenceStorageController:
     current state and forward extremity changes.
     """
 
-    def __init__(self, hs: "HomeServer", stores: Databases):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        stores: Databases,
+        state_controller: StateStorageController,
+    ):
         # We ultimately want to split out the state store from the main store,
         # so we use separate variables here even though they point to the same
         # store for now.
@@ -325,6 +338,7 @@ class EventsPersistenceStorageController:
             self._process_event_persist_queue_task
         )
         self._state_resolution_handler = hs.get_state_resolution_handler()
+        self._state_controller = state_controller
 
     async def _process_event_persist_queue_task(
         self,
@@ -347,7 +361,7 @@ class EventsPersistenceStorageController:
                 f"Found an unexpected task type in event persistence queue: {task}"
             )
 
-    @opentracing.trace
+    @trace
     async def persist_events(
         self,
         events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
@@ -372,9 +386,21 @@ class EventsPersistenceStorageController:
             PartialStateConflictError: if attempting to persist a partial state event in
                 a room that has been un-partial stated.
         """
+        event_ids: List[str] = []
         partitioned: Dict[str, List[Tuple[EventBase, EventContext]]] = {}
         for event, ctx in events_and_contexts:
             partitioned.setdefault(event.room_id, []).append((event, ctx))
+            event_ids.append(event.event_id)
+
+        set_tag(
+            SynapseTags.FUNC_ARG_PREFIX + "event_ids",
+            str(event_ids),
+        )
+        set_tag(
+            SynapseTags.FUNC_ARG_PREFIX + "event_ids.length",
+            str(len(event_ids)),
+        )
+        set_tag(SynapseTags.FUNC_ARG_PREFIX + "backfilled", str(backfilled))
 
         async def enqueue(
             item: Tuple[str, List[Tuple[EventBase, EventContext]]]
@@ -410,7 +436,7 @@ class EventsPersistenceStorageController:
             self.main_store.get_room_max_token(),
         )
 
-    @opentracing.trace
+    @trace
     async def persist_event(
         self, event: EventBase, context: EventContext, backfilled: bool = False
     ) -> Tuple[EventBase, PersistedEventPosition, RoomStreamToken]:
@@ -504,7 +530,7 @@ class EventsPersistenceStorageController:
             state_res_store=StateResolutionStore(self.main_store),
         )
 
-        return res.state
+        return await res.get_state(self._state_controller, StateFilter.all())
 
     async def _persist_event_batch(
         self, _room_id: str, task: _PersistEventsTask
@@ -572,9 +598,9 @@ class EventsPersistenceStorageController:
             # room
             state_delta_for_room: Dict[str, DeltaState] = {}
 
-            # Set of remote users which were in rooms the server has left. We
-            # should check if we still share any rooms and if not we mark their
-            # device lists as stale.
+            # Set of remote users which were in rooms the server has left or who may
+            # have left rooms the server is in. We should check if we still share any
+            # rooms and if not we mark their device lists as stale.
             potentially_left_users: Set[str] = set()
 
             if not backfilled:
@@ -699,6 +725,20 @@ class EventsPersistenceStorageController:
                                 current_state = {}
                                 delta.no_longer_in_room = True
 
+                            # Add all remote users that might have left rooms.
+                            potentially_left_users.update(
+                                user_id
+                                for event_type, user_id in delta.to_delete
+                                if event_type == EventTypes.Member
+                                and not self.is_mine_id(user_id)
+                            )
+                            potentially_left_users.update(
+                                user_id
+                                for event_type, user_id in delta.to_insert.keys()
+                                if event_type == EventTypes.Member
+                                and not self.is_mine_id(user_id)
+                            )
+
                             state_delta_for_room[room_id] = delta
 
             await self.persist_events_store._persist_events_and_state_updates(
@@ -940,7 +980,8 @@ class EventsPersistenceStorageController:
                 events_context,
             )
 
-        return res.state, None, new_latest_event_ids
+        full_state = await res.get_state(self._state_controller)
+        return full_state, None, new_latest_event_ids
 
     async def _prune_extremities(
         self,
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index d3a44bc876..bbe568bf05 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -23,18 +23,20 @@ from typing import (
     List,
     Mapping,
     Optional,
-    Set,
     Tuple,
 )
 
 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
+from synapse.logging.opentracing import tag_args, trace
+from synapse.storage.roommember import ProfileInfo
 from synapse.storage.state import StateFilter
 from synapse.storage.util.partial_state_events_tracker import (
     PartialCurrentStateTracker,
     PartialStateEventsTracker,
 )
 from synapse.types import MutableStateMap, StateMap
+from synapse.util.cancellation import cancellable
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -82,13 +84,15 @@ class StateStorageController:
         return state_group_delta.prev_group, state_group_delta.delta_ids
 
     async def get_state_groups_ids(
-        self, _room_id: str, event_ids: Collection[str]
+        self, _room_id: str, event_ids: Collection[str], await_full_state: bool = True
     ) -> Dict[int, MutableStateMap[str]]:
         """Get the event IDs of all the state for the state groups for the given events
 
         Args:
             _room_id: id of the room for these events
             event_ids: ids of the events
+            await_full_state: if `True`, will block if we do not yet have complete
+               state at these events.
 
         Returns:
             dict of state_group_id -> (dict of (type, state_key) -> event id)
@@ -100,7 +104,9 @@ class StateStorageController:
         if not event_ids:
             return {}
 
-        event_to_groups = await self.get_state_group_for_events(event_ids)
+        event_to_groups = await self.get_state_group_for_events(
+            event_ids, await_full_state=await_full_state
+        )
 
         groups = set(event_to_groups.values())
         group_to_state = await self.stores.state._get_state_for_groups(groups)
@@ -175,6 +181,7 @@ class StateStorageController:
 
         return self.stores.state._get_state_groups_from_groups(groups, state_filter)
 
+    @trace
     async def get_state_for_events(
         self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
     ) -> Dict[str, StateMap[EventBase]]:
@@ -221,10 +228,14 @@ class StateStorageController:
 
         return {event: event_to_state[event] for event in event_ids}
 
+    @trace
+    @tag_args
+    @cancellable
     async def get_state_ids_for_events(
         self,
         event_ids: Collection[str],
         state_filter: Optional[StateFilter] = None,
+        await_full_state: bool = True,
     ) -> Dict[str, StateMap[str]]:
         """
         Get the state dicts corresponding to a list of events, containing the event_ids
@@ -233,6 +244,9 @@ class StateStorageController:
         Args:
             event_ids: events whose state should be returned
             state_filter: The state filter used to fetch state from the database.
+            await_full_state: if `True`, will block if we do not yet have complete state
+                at these events and `state_filter` is not satisfied by partial state.
+                Defaults to `True`.
 
         Returns:
             A dict from event_id -> (type, state_key) -> event_id
@@ -241,8 +255,12 @@ class StateStorageController:
             RuntimeError if we don't have a state group for one or more of the events
                 (ie they are outliers or unknown)
         """
-        await_full_state = True
-        if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
+        if (
+            await_full_state
+            and state_filter
+            and not state_filter.must_await_full_state(self._is_mine_id)
+        ):
+            # Full state is not required if the state filter is restrictive enough.
             await_full_state = False
 
         event_to_groups = await self.get_state_group_for_events(
@@ -283,8 +301,12 @@ class StateStorageController:
         )
         return state_map[event_id]
 
+    @trace
     async def get_state_ids_for_event(
-        self, event_id: str, state_filter: Optional[StateFilter] = None
+        self,
+        event_id: str,
+        state_filter: Optional[StateFilter] = None,
+        await_full_state: bool = True,
     ) -> StateMap[str]:
         """
         Get the state dict corresponding to a particular event
@@ -292,6 +314,9 @@ class StateStorageController:
         Args:
             event_id: event whose state should be returned
             state_filter: The state filter used to fetch state from the database.
+            await_full_state: if `True`, will block if we do not yet have complete state
+                at the event and `state_filter` is not satisfied by partial state.
+                Defaults to `True`.
 
         Returns:
             A dict from (type, state_key) -> state_event_id
@@ -301,7 +326,9 @@ class StateStorageController:
                 outlier or is unknown)
         """
         state_map = await self.get_state_ids_for_events(
-            [event_id], state_filter or StateFilter.all()
+            [event_id],
+            state_filter or StateFilter.all(),
+            await_full_state=await_full_state,
         )
         return state_map[event_id]
 
@@ -323,6 +350,9 @@ class StateStorageController:
             groups, state_filter or StateFilter.all()
         )
 
+    @trace
+    @tag_args
+    @cancellable
     async def get_state_group_for_events(
         self,
         event_ids: Collection[str],
@@ -334,6 +364,10 @@ class StateStorageController:
             event_ids: events to get state groups for
             await_full_state: if true, will block if we do not yet have complete
                state at these events.
+
+        Raises:
+            RuntimeError if we don't have a state group for one or more of the events
+               (ie. they are outliers or unknown)
         """
         if await_full_state:
             await self._partial_state_events_tracker.await_full_state(event_ids)
@@ -346,7 +380,7 @@ class StateStorageController:
         room_id: str,
         prev_group: Optional[int],
         delta_ids: Optional[StateMap[str]],
-        current_state_ids: StateMap[str],
+        current_state_ids: Optional[StateMap[str]],
     ) -> int:
         """Store a new set of state, returning a newly assigned state group.
 
@@ -367,6 +401,7 @@ class StateStorageController:
             event_id, room_id, prev_group, delta_ids, current_state_ids
         )
 
+    @cancellable
     async def get_current_state_ids(
         self,
         room_id: str,
@@ -460,6 +495,7 @@ class StateStorageController:
             prev_stream_id, max_stream_id
         )
 
+    @trace
     async def get_current_state(
         self, room_id: str, state_filter: Optional[StateFilter] = None
     ) -> StateMap[EventBase]:
@@ -487,9 +523,21 @@ class StateStorageController:
         )
         return state_map.get(key)
 
-    async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
+    async def get_current_hosts_in_room(self, room_id: str) -> List[str]:
         """Get current hosts in room based on current state."""
 
         await self._partial_state_room_tracker.await_full_state(room_id)
 
         return await self.stores.main.get_current_hosts_in_room(room_id)
+
+    async def get_users_in_room_with_profiles(
+        self, room_id: str
+    ) -> Dict[str, ProfileInfo]:
+        """
+        Get the current users in the room with their profiles.
+        If the room is currently partial-stated, this will block until the room has
+        full state.
+        """
+        await self._partial_state_room_tracker.await_full_state(room_id)
+
+        return await self.stores.main.get_users_in_room_with_profiles(room_id)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index e21ab08515..e881bff7fb 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -23,6 +23,7 @@ from time import monotonic as monotonic_time
 from typing import (
     TYPE_CHECKING,
     Any,
+    Awaitable,
     Callable,
     Collection,
     Dict,
@@ -38,7 +39,7 @@ from typing import (
 )
 
 import attr
-from prometheus_client import Histogram
+from prometheus_client import Counter, Histogram
 from typing_extensions import Concatenate, Literal, ParamSpec
 
 from twisted.enterprise import adbapi
@@ -75,7 +76,8 @@ perf_logger = logging.getLogger("synapse.storage.TIME")
 sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec")
 
 sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"])
-sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"])
+sql_txn_count = Counter("synapse_storage_transaction_time_count", "sec", ["desc"])
+sql_txn_duration = Counter("synapse_storage_transaction_time_sum", "sec", ["desc"])
 
 
 # Unique indexes which have been added in background updates. Maps from table name
@@ -168,6 +170,7 @@ class LoggingDatabaseConnection:
         *,
         txn_name: Optional[str] = None,
         after_callbacks: Optional[List["_CallbackListEntry"]] = None,
+        async_after_callbacks: Optional[List["_AsyncCallbackListEntry"]] = None,
         exception_callbacks: Optional[List["_CallbackListEntry"]] = None,
     ) -> "LoggingTransaction":
         if not txn_name:
@@ -178,6 +181,7 @@ class LoggingDatabaseConnection:
             name=txn_name,
             database_engine=self.engine,
             after_callbacks=after_callbacks,
+            async_after_callbacks=async_after_callbacks,
             exception_callbacks=exception_callbacks,
         )
 
@@ -209,6 +213,9 @@ class LoggingDatabaseConnection:
 
 # The type of entry which goes on our after_callbacks and exception_callbacks lists.
 _CallbackListEntry = Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]]
+_AsyncCallbackListEntry = Tuple[
+    Callable[..., Awaitable], Tuple[object, ...], Dict[str, object]
+]
 
 P = ParamSpec("P")
 R = TypeVar("R")
@@ -227,6 +234,10 @@ class LoggingTransaction:
             that have been added by `call_after` which should be run on
             successful completion of the transaction. None indicates that no
             callbacks should be allowed to be scheduled to run.
+        async_after_callbacks: A list that asynchronous callbacks will be appended
+            to by `async_call_after` which should run, before after_callbacks, on
+            successful completion of the transaction. None indicates that no
+            callbacks should be allowed to be scheduled to run.
         exception_callbacks: A list that callbacks will be appended
             to that have been added by `call_on_exception` which should be run
             if transaction ends with an error. None indicates that no callbacks
@@ -238,6 +249,7 @@ class LoggingTransaction:
         "name",
         "database_engine",
         "after_callbacks",
+        "async_after_callbacks",
         "exception_callbacks",
     ]
 
@@ -247,12 +259,14 @@ class LoggingTransaction:
         name: str,
         database_engine: BaseDatabaseEngine,
         after_callbacks: Optional[List[_CallbackListEntry]] = None,
+        async_after_callbacks: Optional[List[_AsyncCallbackListEntry]] = None,
         exception_callbacks: Optional[List[_CallbackListEntry]] = None,
     ):
         self.txn = txn
         self.name = name
         self.database_engine = database_engine
         self.after_callbacks = after_callbacks
+        self.async_after_callbacks = async_after_callbacks
         self.exception_callbacks = exception_callbacks
 
     def call_after(
@@ -277,6 +291,28 @@ class LoggingTransaction:
         # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
         self.after_callbacks.append((callback, args, kwargs))  # type: ignore[arg-type]
 
+    def async_call_after(
+        self, callback: Callable[P, Awaitable], *args: P.args, **kwargs: P.kwargs
+    ) -> None:
+        """Call the given asynchronous callback on the main twisted thread after
+        the transaction has finished (but before those added in `call_after`).
+
+        Mostly used to invalidate remote caches after transactions.
+
+        Note that transactions may be retried a few times if they encounter database
+        errors such as serialization failures. Callbacks given to `async_call_after`
+        will accumulate across transaction attempts and will _all_ be called once a
+        transaction attempt succeeds, regardless of whether previous transaction
+        attempts failed. Otherwise, if all transaction attempts fail, all
+        `call_on_exception` callbacks will be run instead.
+        """
+        # if self.async_after_callbacks is None, that means that whatever constructed the
+        # LoggingTransaction isn't expecting there to be any callbacks; assert that
+        # is not the case.
+        assert self.async_after_callbacks is not None
+        # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
+        self.async_after_callbacks.append((callback, args, kwargs))  # type: ignore[arg-type]
+
     def call_on_exception(
         self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs
     ) -> None:
@@ -497,15 +533,14 @@ class DatabasePool:
         if isinstance(self.engine, Sqlite3Engine):
             self._unsafe_to_upsert_tables.add("user_directory_search")
 
-        if self.engine.can_native_upsert:
-            # Check ASAP (and then later, every 1s) to see if we have finished
-            # background updates of tables that aren't safe to update.
-            self._clock.call_later(
-                0.0,
-                run_as_background_process,
-                "upsert_safety_check",
-                self._check_safe_to_upsert,
-            )
+        # Check ASAP (and then later, every 1s) to see if we have finished
+        # background updates of tables that aren't safe to update.
+        self._clock.call_later(
+            0.0,
+            run_as_background_process,
+            "upsert_safety_check",
+            self._check_safe_to_upsert,
+        )
 
     def name(self) -> str:
         "Return the name of this database"
@@ -574,6 +609,7 @@ class DatabasePool:
         conn: LoggingDatabaseConnection,
         desc: str,
         after_callbacks: List[_CallbackListEntry],
+        async_after_callbacks: List[_AsyncCallbackListEntry],
         exception_callbacks: List[_CallbackListEntry],
         func: Callable[Concatenate[LoggingTransaction, P], R],
         *args: P.args,
@@ -597,6 +633,7 @@ class DatabasePool:
             conn
             desc
             after_callbacks
+            async_after_callbacks
             exception_callbacks
             func
             *args
@@ -659,6 +696,7 @@ class DatabasePool:
                 cursor = conn.cursor(
                     txn_name=name,
                     after_callbacks=after_callbacks,
+                    async_after_callbacks=async_after_callbacks,
                     exception_callbacks=exception_callbacks,
                 )
                 try:
@@ -757,7 +795,8 @@ class DatabasePool:
 
             self._current_txn_total_time += duration
             self._txn_perf_counters.update(desc, duration)
-            sql_txn_timer.labels(desc).observe(duration)
+            sql_txn_count.labels(desc).inc(1)
+            sql_txn_duration.labels(desc).inc(duration)
 
     async def runInteraction(
         self,
@@ -798,6 +837,7 @@ class DatabasePool:
 
         async def _runInteraction() -> R:
             after_callbacks: List[_CallbackListEntry] = []
+            async_after_callbacks: List[_AsyncCallbackListEntry] = []
             exception_callbacks: List[_CallbackListEntry] = []
 
             if not current_context():
@@ -809,6 +849,7 @@ class DatabasePool:
                         self.new_transaction,
                         desc,
                         after_callbacks,
+                        async_after_callbacks,
                         exception_callbacks,
                         func,
                         *args,
@@ -817,13 +858,17 @@ class DatabasePool:
                         **kwargs,
                     )
 
+                # We order these assuming that async functions call out to external
+                # systems (e.g. to invalidate a cache) and the sync functions make these
+                # changes on any local in-memory caches/similar, and thus must be second.
+                for async_callback, async_args, async_kwargs in async_after_callbacks:
+                    await async_callback(*async_args, **async_kwargs)
                 for after_callback, after_args, after_kwargs in after_callbacks:
                     after_callback(*after_args, **after_kwargs)
-
                 return cast(R, result)
             except Exception:
-                for after_callback, after_args, after_kwargs in exception_callbacks:
-                    after_callback(*after_args, **after_kwargs)
+                for exception_callback, after_args, after_kwargs in exception_callbacks:
+                    exception_callback(*after_args, **after_kwargs)
                 raise
 
         # To handle cancellation, we ensure that `after_callback`s and
@@ -1114,11 +1159,8 @@ class DatabasePool:
         attempts = 0
         while True:
             try:
-                # We can autocommit if we are going to use native upserts
-                autocommit = (
-                    self.engine.can_native_upsert
-                    and table not in self._unsafe_to_upsert_tables
-                )
+                # We can autocommit if it is safe to upsert
+                autocommit = table not in self._unsafe_to_upsert_tables
 
                 return await self.runInteraction(
                     desc,
@@ -1153,7 +1195,7 @@ class DatabasePool:
     ) -> bool:
         """
         Pick the UPSERT method which works best on the platform. Either the
-        native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
+        native one (Pg9.5+, SQLite >= 3.24), or fall back to an emulated method.
 
         Args:
             txn: The transaction to use.
@@ -1161,14 +1203,15 @@ class DatabasePool:
             keyvalues: The unique key tables and their new values
             values: The nonunique columns and their new values
             insertion_values: additional key/values to use only when inserting
-            lock: True to lock the table when doing the upsert.
+            lock: True to lock the table when doing the upsert. Unused when performing
+                a native upsert.
         Returns:
             Returns True if a row was inserted or updated (i.e. if `values` is
             not empty then this always returns True)
         """
         insertion_values = insertion_values or {}
 
-        if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
+        if table not in self._unsafe_to_upsert_tables:
             return self.simple_upsert_txn_native_upsert(
                 txn, table, keyvalues, values, insertion_values=insertion_values
             )
@@ -1319,14 +1362,12 @@ class DatabasePool:
             value_names: The value column names
             value_values: A list of each row's value column values.
                 Ignored if value_names is empty.
-            lock: True to lock the table when doing the upsert. Unused if the database engine
-                supports native upserts.
+            lock: True to lock the table when doing the upsert. Unused when performing
+                a native upsert.
         """
 
-        # We can autocommit if we are going to use native upserts
-        autocommit = (
-            self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables
-        )
+        # We can autocommit if it safe to upsert
+        autocommit = table not in self._unsafe_to_upsert_tables
 
         await self.runInteraction(
             desc,
@@ -1360,10 +1401,10 @@ class DatabasePool:
             value_names: The value column names
             value_values: A list of each row's value column values.
                 Ignored if value_names is empty.
-            lock: True to lock the table when doing the upsert. Unused if the database engine
-                supports native upserts.
+            lock: True to lock the table when doing the upsert. Unused when performing
+                a native upsert.
         """
-        if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
+        if table not in self._unsafe_to_upsert_tables:
             return self.simple_upsert_many_txn_native_upsert(
                 txn, table, key_names, key_values, value_names, value_values
             )
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index a3d31d3737..4dccbb732a 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -24,9 +24,9 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.databases.main.stats import UserSortOrder
-from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.engines import BaseDatabaseEngine
 from synapse.storage.types import Cursor
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.types import JsonDict, get_domain_from_id
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
@@ -149,31 +149,6 @@ class DataStore(
             ],
         )
 
-        self._cache_id_gen: Optional[MultiWriterIdGenerator]
-        if isinstance(self.database_engine, PostgresEngine):
-            # We set the `writers` to an empty list here as we don't care about
-            # missing updates over restarts, as we'll not have anything in our
-            # caches to invalidate. (This reduces the amount of writes to the DB
-            # that happen).
-            self._cache_id_gen = MultiWriterIdGenerator(
-                db_conn,
-                database,
-                stream_name="caches",
-                instance_name=hs.get_instance_name(),
-                tables=[
-                    (
-                        "cache_invalidation_stream_by_instance",
-                        "instance_name",
-                        "stream_id",
-                    )
-                ],
-                sequence_name="cache_invalidation_stream_seq",
-                writers=[],
-            )
-
-        else:
-            self._cache_id_gen = None
-
         super().__init__(database, db_conn, hs)
 
         events_max = self._stream_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 9af9f4f18e..c38b8a9e5a 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -650,9 +650,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             txn, self.get_account_data_for_room, (user_id,)
         )
         self._invalidate_cache_and_stream(txn, self.get_push_rules_for_user, (user_id,))
-        self._invalidate_cache_and_stream(
-            txn, self.get_push_rules_enabled_for_user, (user_id,)
-        )
         # This user might be contained in the ignored_by cache for other users,
         # so we have to invalidate it all.
         self._invalidate_all_cache_and_stream(txn, self.ignored_by)
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index e284454b66..64b70a7b28 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -371,52 +371,30 @@ class ApplicationServiceTransactionWorkerStore(
             device_list_summary=DeviceListUpdates(),
         )
 
-    async def set_appservice_last_pos(self, pos: int) -> None:
-        def set_appservice_last_pos_txn(txn: LoggingTransaction) -> None:
-            txn.execute(
-                "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
-            )
+    async def get_appservice_last_pos(self) -> int:
+        """
+        Get the last stream ordering position for the appservice process.
+        """
 
-        await self.db_pool.runInteraction(
-            "set_appservice_last_pos", set_appservice_last_pos_txn
+        return await self.db_pool.simple_select_one_onecol(
+            table="appservice_stream_position",
+            retcol="stream_ordering",
+            keyvalues={},
+            desc="get_appservice_last_pos",
         )
 
-    async def get_new_events_for_appservice(
-        self, current_id: int, limit: int
-    ) -> Tuple[int, List[EventBase]]:
-        """Get all new events for an appservice"""
-
-        def get_new_events_for_appservice_txn(
-            txn: LoggingTransaction,
-        ) -> Tuple[int, List[str]]:
-            sql = (
-                "SELECT e.stream_ordering, e.event_id"
-                " FROM events AS e"
-                " WHERE"
-                " (SELECT stream_ordering FROM appservice_stream_position)"
-                "     < e.stream_ordering"
-                " AND e.stream_ordering <= ?"
-                " ORDER BY e.stream_ordering ASC"
-                " LIMIT ?"
-            )
-
-            txn.execute(sql, (current_id, limit))
-            rows = txn.fetchall()
-
-            upper_bound = current_id
-            if len(rows) == limit:
-                upper_bound = rows[-1][0]
-
-            return upper_bound, [row[1] for row in rows]
+    async def set_appservice_last_pos(self, pos: int) -> None:
+        """
+        Set the last stream ordering position for the appservice process.
+        """
 
-        upper_bound, event_ids = await self.db_pool.runInteraction(
-            "get_new_events_for_appservice", get_new_events_for_appservice_txn
+        await self.db_pool.simple_update_one(
+            table="appservice_stream_position",
+            keyvalues={},
+            updatevalues={"stream_ordering": pos},
+            desc="set_appservice_last_pos",
         )
 
-        events = await self.get_events_as_list(event_ids, get_prev_content=True)
-
-        return upper_bound, events
-
     async def get_type_stream_id_for_appservice(
         self, service: ApplicationService, type: str
     ) -> int:
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 1653a6a9b6..12e9a42382 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -32,6 +32,7 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.util.caches.descriptors import _CachedFunction
 from synapse.util.iterutils import batch_iter
 
@@ -65,6 +66,31 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             psql_only=True,  # The table is only on postgres DBs.
         )
 
+        self._cache_id_gen: Optional[MultiWriterIdGenerator]
+        if isinstance(self.database_engine, PostgresEngine):
+            # We set the `writers` to an empty list here as we don't care about
+            # missing updates over restarts, as we'll not have anything in our
+            # caches to invalidate. (This reduces the amount of writes to the DB
+            # that happen).
+            self._cache_id_gen = MultiWriterIdGenerator(
+                db_conn,
+                database,
+                stream_name="caches",
+                instance_name=hs.get_instance_name(),
+                tables=[
+                    (
+                        "cache_invalidation_stream_by_instance",
+                        "instance_name",
+                        "stream_id",
+                    )
+                ],
+                sequence_name="cache_invalidation_stream_seq",
+                writers=[],
+            )
+
+        else:
+            self._cache_id_gen = None
+
     async def get_all_updated_caches(
         self, instance_name: str, last_id: int, current_id: int, limit: int
     ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
@@ -193,7 +219,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         relates_to: Optional[str],
         backfilled: bool,
     ) -> None:
-        self._invalidate_get_event_cache(event_id)
+        # This invalidates any local in-memory cached event objects, the original
+        # process triggering the invalidation is responsible for clearing any external
+        # cached objects.
+        self._invalidate_local_get_event_cache(event_id)
         self.have_seen_event.invalidate((room_id, event_id))
 
         self.get_latest_event_ids_in_room.invalidate((room_id,))
@@ -208,7 +237,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
 
         if redacts:
-            self._invalidate_get_event_cache(redacts)
+            self._invalidate_local_get_event_cache(redacts)
             # Caches which might leak edits must be invalidated for the event being
             # redacted.
             self.get_relations_for_event.invalidate((redacts,))
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index fd3fc298b3..58177ecec1 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -194,7 +194,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
             # changed its content in the database. We can't call
             # self._invalidate_cache_and_stream because self.get_event_cache isn't of the
             # right type.
-            txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
+            self.invalidate_get_event_cache_after_txn(txn, event.event_id)
             # Send that invalidation to replication so that other workers also invalidate
             # the event cache.
             self._send_invalidation_to_replication(
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 422e0e65ca..73c95ffb6f 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -436,7 +436,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             (user_id, device_id), None
         )
 
-        set_tag("last_deleted_stream_id", last_deleted_stream_id)
+        set_tag("last_deleted_stream_id", str(last_deleted_stream_id))
 
         if last_deleted_stream_id:
             has_changed = self._device_inbox_stream_cache.has_entity_changed(
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index adde5d0978..5d700ca6c3 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -53,6 +53,7 @@ from synapse.util import json_decoder, json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.util.cancellation import cancellable
 from synapse.util.iterutils import batch_iter
 from synapse.util.stringutils import shortstr
 
@@ -668,8 +669,9 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
         ...
 
     @trace
+    @cancellable
     async def get_user_devices_from_cache(
-        self, query_list: List[Tuple[str, str]]
+        self, query_list: List[Tuple[str, Optional[str]]]
     ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
         """Get the devices (and keys if any) for remote users from the cache.
 
@@ -706,8 +708,8 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
             else:
                 results[user_id] = await self.get_cached_devices_for_user(user_id)
 
-        set_tag("in_cache", results)
-        set_tag("not_in_cache", user_ids_not_in_cache)
+        set_tag("in_cache", str(results))
+        set_tag("not_in_cache", str(user_ids_not_in_cache))
 
         return user_ids_not_in_cache, results
 
@@ -743,6 +745,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
 
         return self._device_list_stream_cache.get_all_entities_changed(from_key)
 
+    @cancellable
     async def get_users_whose_devices_changed(
         self,
         from_key: int,
@@ -1221,6 +1224,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
             desc="get_min_device_lists_changes_in_room",
         )
 
+    @cancellable
     async def get_device_list_changes_in_rooms(
         self, room_ids: Collection[str], from_id: int
     ) -> Optional[Set[str]]:
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 9b293475c8..8e9e1b0b4b 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -22,11 +22,14 @@ from typing import (
     List,
     Optional,
     Tuple,
+    Union,
     cast,
+    overload,
 )
 
 import attr
 from canonicaljson import encode_canonical_json
+from typing_extensions import Literal
 
 from synapse.api.constants import DeviceKeyAlgorithms
 from synapse.appservice import (
@@ -47,6 +50,7 @@ from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.cancellation import cancellable
 from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
@@ -113,7 +117,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             user_devices = devices[user_id]
             results = []
             for device_id, device in user_devices.items():
-                result = {"device_id": device_id}
+                result: JsonDict = {"device_id": device_id}
 
                 keys = device.keys
                 if keys:
@@ -132,6 +136,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         return now_stream_id, []
 
     @trace
+    @cancellable
     async def get_e2e_device_keys_for_cs_api(
         self, query_list: List[Tuple[str, Optional[str]]]
     ) -> Dict[str, Dict[str, JsonDict]]:
@@ -143,7 +148,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             key data.  The key data will be a dict in the same format as the
             DeviceKeys type returned by POST /_matrix/client/r0/keys/query.
         """
-        set_tag("query_list", query_list)
+        set_tag("query_list", str(query_list))
         if not query_list:
             return {}
 
@@ -156,6 +161,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             rv[user_id] = {}
             for device_id, device_info in device_keys.items():
                 r = device_info.keys
+                if r is None:
+                    continue
+
                 r["unsigned"] = {}
                 display_name = device_info.display_name
                 if display_name is not None:
@@ -164,13 +172,43 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 
         return rv
 
+    @overload
+    async def get_e2e_device_keys_and_signatures(
+        self,
+        query_list: Collection[Tuple[str, Optional[str]]],
+        include_all_devices: Literal[False] = False,
+    ) -> Dict[str, Dict[str, DeviceKeyLookupResult]]:
+        ...
+
+    @overload
+    async def get_e2e_device_keys_and_signatures(
+        self,
+        query_list: Collection[Tuple[str, Optional[str]]],
+        include_all_devices: bool = False,
+        include_deleted_devices: Literal[False] = False,
+    ) -> Dict[str, Dict[str, DeviceKeyLookupResult]]:
+        ...
+
+    @overload
+    async def get_e2e_device_keys_and_signatures(
+        self,
+        query_list: Collection[Tuple[str, Optional[str]]],
+        include_all_devices: Literal[True],
+        include_deleted_devices: Literal[True],
+    ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
+        ...
+
     @trace
+    @cancellable
     async def get_e2e_device_keys_and_signatures(
         self,
-        query_list: List[Tuple[str, Optional[str]]],
+        query_list: Collection[Tuple[str, Optional[str]]],
         include_all_devices: bool = False,
         include_deleted_devices: bool = False,
-    ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
+    ) -> Union[
+        Dict[str, Dict[str, DeviceKeyLookupResult]],
+        Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]],
+    ]:
         """Fetch a list of device keys
 
         Any cross-signatures made on the keys by the owner of the device are also
@@ -383,7 +421,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
             set_tag("user_id", user_id)
             set_tag("device_id", device_id)
-            set_tag("new_keys", new_keys)
+            set_tag("new_keys", str(new_keys))
             # We are protected from race between lookup and insertion due to
             # a unique constraint. If there is a race of two calls to
             # `add_e2e_one_time_keys` then they'll conflict and we will only
@@ -852,6 +890,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 
         return keys
 
+    @cancellable
     async def get_e2e_cross_signing_keys_bulk(
         self, user_ids: List[str], from_user_id: Optional[str] = None
     ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
@@ -867,7 +906,6 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             keys were not found, either their user ID will not be in the dict,
             or their user ID will map to None.
         """
-
         result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
 
         if from_user_id:
@@ -1044,7 +1082,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
                 _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
                 db_autocommit = False
 
-            row = await self.db_pool.runInteraction(
+            claim_row = await self.db_pool.runInteraction(
                 "claim_e2e_one_time_keys",
                 _claim_e2e_one_time_key,
                 user_id,
@@ -1052,11 +1090,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
                 algorithm,
                 db_autocommit=db_autocommit,
             )
-            if row:
+            if claim_row:
                 device_results = results.setdefault(user_id, {}).setdefault(
                     device_id, {}
                 )
-                device_results[row[0]] = row[1]
+                device_results[claim_row[0]] = claim_row[1]
                 continue
 
             # No one-time key available, so see if there's a fallback
@@ -1126,7 +1164,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             set_tag("user_id", user_id)
             set_tag("device_id", device_id)
             set_tag("time_now", time_now)
-            set_tag("device_keys", device_keys)
+            set_tag("device_keys", str(device_keys))
 
             old_key_json = self.db_pool.simple_select_one_onecol_txn(
                 txn,
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index eec55b6478..ef477978ed 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -33,6 +33,7 @@ from synapse.api.constants import MAX_DEPTH, EventTypes
 from synapse.api.errors import StoreError
 from synapse.api.room_versions import EventFormatVersions, RoomVersion
 from synapse.events import EventBase, make_event_from_dict
+from synapse.logging.opentracing import tag_args, trace
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
@@ -47,6 +48,7 @@ from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.lrucache import LruCache
+from synapse.util.cancellation import cancellable
 from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
@@ -126,6 +128,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         )
         return await self.get_events_as_list(event_ids)
 
+    @trace
+    @tag_args
     async def get_auth_chain_ids(
         self,
         room_id: str,
@@ -709,6 +713,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         # Return all events where not all sets can reach them.
         return {eid for eid, n in event_to_missing_sets.items() if n}
 
+    @trace
+    @tag_args
     async def get_oldest_event_ids_with_depth_in_room(
         self, room_id: str
     ) -> List[Tuple[str, int]]:
@@ -767,6 +773,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             room_id,
         )
 
+    @trace
     async def get_insertion_event_backward_extremities_in_room(
         self, room_id: str
     ) -> List[Tuple[str, int]]:
@@ -970,6 +977,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         return int(min_depth) if min_depth is not None else None
 
+    @cancellable
     async def get_forward_extremities_for_room_at_stream_ordering(
         self, room_id: str, stream_ordering: int
     ) -> List[str]:
@@ -1286,6 +1294,51 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         return event_id_results
 
+    @trace
+    async def record_event_failed_pull_attempt(
+        self, room_id: str, event_id: str, cause: str
+    ) -> None:
+        """
+        Record when we fail to pull an event over federation.
+
+        This information allows us to be more intelligent when we decide to
+        retry (we don't need to fail over and over) and we can process that
+        event in the background so we don't block on it each time.
+
+        Args:
+            room_id: The room where the event failed to pull from
+            event_id: The event that failed to be fetched or processed
+            cause: The error message or reason that we failed to pull the event
+        """
+        await self.db_pool.runInteraction(
+            "record_event_failed_pull_attempt",
+            self._record_event_failed_pull_attempt_upsert_txn,
+            room_id,
+            event_id,
+            cause,
+            db_autocommit=True,  # Safe as it's a single upsert
+        )
+
+    def _record_event_failed_pull_attempt_upsert_txn(
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        event_id: str,
+        cause: str,
+    ) -> None:
+        sql = """
+            INSERT INTO event_failed_pull_attempts (
+                room_id, event_id, num_attempts, last_attempt_ts, last_cause
+            )
+                VALUES (?, ?, ?, ?, ?)
+            ON CONFLICT (room_id, event_id) DO UPDATE SET
+                num_attempts=event_failed_pull_attempts.num_attempts + 1,
+                last_attempt_ts=EXCLUDED.last_attempt_ts,
+                last_cause=EXCLUDED.last_cause;
+        """
+
+        txn.execute(sql, (room_id, event_id, 1, self._clock.time_msec(), cause))
+
     async def get_missing_events(
         self,
         room_id: str,
@@ -1339,6 +1392,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         event_results.reverse()
         return event_results
 
+    @trace
+    @tag_args
     async def get_successor_events(self, event_id: str) -> List[str]:
         """Fetch all events that have the given event as a prev event
 
@@ -1375,6 +1430,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             _delete_old_forward_extrem_cache_txn,
         )
 
+    @trace
     async def insert_insertion_extremity(self, event_id: str, room_id: str) -> None:
         await self.db_pool.simple_upsert(
             table="insertion_event_extremities",
@@ -1597,7 +1653,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                 logger.info("Invalid prev_events for %s", event_id)
                 continue
 
-            if room_version.event_format == EventFormatVersions.V1:
+            if room_version.event_format == EventFormatVersions.ROOM_V1_V2:
                 for prev_event_tuple in prev_events:
                     if (
                         not isinstance(prev_event_tuple, list)
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index dd2627037c..6b8668d2dc 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -12,14 +12,85 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+"""Responsible for storing and fetching push actions / notifications.
+
+There are two main uses for push actions:
+  1. Sending out push to a user's device; and
+  2. Tracking per-room per-user notification counts (used in sync requests).
+
+For the former we simply use the `event_push_actions` table, which contains all
+the calculated actions for a given user (which were calculated by the
+`BulkPushRuleEvaluator`).
+
+For the latter we could simply count the number of rows in `event_push_actions`
+table for a given room/user, but in practice this is *very* heavyweight when
+there were a large number of notifications (due to e.g. the user never reading a
+room). Plus, keeping all push actions indefinitely uses a lot of disk space.
+
+To fix these issues, we add a new table `event_push_summary` that tracks
+per-user per-room counts of all notifications that happened before a stream
+ordering S. Thus, to get the notification count for a user / room we can simply
+query a single row in `event_push_summary` and count the number of rows in
+`event_push_actions` with a stream ordering larger than S (and as long as S is
+"recent", the number of rows needing to be scanned will be small).
+
+The `event_push_summary` table is updated via a background job that periodically
+chooses a new stream ordering S' (usually the latest stream ordering), counts
+all notifications in `event_push_actions` between the existing S and S', and
+adds them to the existing counts in `event_push_summary`.
+
+This allows us to delete old rows from `event_push_actions` once those rows have
+been counted and added to `event_push_summary` (we call this process
+"rotation").
+
+
+We need to handle when a user sends a read receipt to the room. Again this is
+done as a background process. For each receipt we clear the row in
+`event_push_summary` and count the number of notifications in
+`event_push_actions` that happened after the receipt but before S, and insert
+that count into `event_push_summary` (If the receipt happened *after* S then we
+simply clear the `event_push_summary`.)
+
+Note that its possible that if the read receipt is for an old event the relevant
+`event_push_actions` rows will have been rotated and we get the wrong count
+(it'll be too low). We accept this as a rare edge case that is unlikely to
+impact the user much (since the vast majority of read receipts will be for the
+latest event).
+
+The last complication is to handle the race where we request the notifications
+counts after a user sends a read receipt into the room, but *before* the
+background update handles the receipt (without any special handling the counts
+would be outdated). We fix this by including in `event_push_summary` the read
+receipt we used when updating `event_push_summary`, and every time we query the
+table we check if that matches the most recent read receipt in the room. If yes,
+continue as above, if not we simply query the `event_push_actions` table
+directly.
+
+Since read receipts are almost always for recent events, scanning the
+`event_push_actions` table in this case is unlikely to be a problem. Even if it
+is a problem, it is temporary until the background job handles the new read
+receipt.
+"""
+
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
+from typing import (
+    TYPE_CHECKING,
+    Collection,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+)
 
 import attr
 
 from synapse.api.constants import ReceiptTypes
 from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
     LoggingDatabaseConnection,
@@ -27,6 +98,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 from synapse.storage.databases.main.stream import StreamWorkerStore
+from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 
@@ -93,7 +165,9 @@ class NotifCounts:
     highlight_count: int = 0
 
 
-def _serialize_action(actions: List[Union[dict, str]], is_highlight: bool) -> str:
+def _serialize_action(
+    actions: Collection[Union[Mapping, str]], is_highlight: bool
+) -> str:
     """Custom serializer for actions. This allows us to "compress" common actions.
 
     We use the fact that most users have the same actions for notifs (and for
@@ -159,6 +233,104 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             replaces_index="event_push_summary_user_rm",
         )
 
+        self.db_pool.updates.register_background_index_update(
+            "event_push_summary_unique_index2",
+            index_name="event_push_summary_unique_index2",
+            table="event_push_summary",
+            columns=["user_id", "room_id", "thread_id"],
+            unique=True,
+        )
+
+        self.db_pool.updates.register_background_update_handler(
+            "event_push_backfill_thread_id",
+            self._background_backfill_thread_id,
+        )
+
+    async def _background_backfill_thread_id(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        """
+        Fill in the thread_id field for event_push_actions and event_push_summary.
+
+        This is preparatory so that it can be made non-nullable in the future.
+
+        Because all current (null) data is done in an unthreaded manner this
+        simply assumes it is on the "main" timeline. Since event_push_actions
+        are periodically cleared it is not possible to correctly re-calculate
+        the thread_id.
+        """
+        event_push_actions_done = progress.get("event_push_actions_done", False)
+
+        def add_thread_id_txn(
+            txn: LoggingTransaction, table_name: str, start_stream_ordering: int
+        ) -> int:
+            sql = f"""
+            SELECT stream_ordering
+            FROM {table_name}
+            WHERE
+                thread_id IS NULL
+                AND stream_ordering > ?
+            ORDER BY stream_ordering
+            LIMIT ?
+            """
+            txn.execute(sql, (start_stream_ordering, batch_size))
+
+            # No more rows to process.
+            rows = txn.fetchall()
+            if not rows:
+                progress[f"{table_name}_done"] = True
+                self.db_pool.updates._background_update_progress_txn(
+                    txn, "event_push_backfill_thread_id", progress
+                )
+                return 0
+
+            # Update the thread ID for any of those rows.
+            max_stream_ordering = rows[-1][0]
+
+            sql = f"""
+            UPDATE {table_name}
+            SET thread_id = 'main'
+            WHERE stream_ordering <= ? AND thread_id IS NULL
+            """
+            txn.execute(sql, (max_stream_ordering,))
+
+            # Update progress.
+            processed_rows = txn.rowcount
+            progress[f"max_{table_name}_stream_ordering"] = max_stream_ordering
+            self.db_pool.updates._background_update_progress_txn(
+                txn, "event_push_backfill_thread_id", progress
+            )
+
+            return processed_rows
+
+        # First update the event_push_actions table, then the event_push_summary table.
+        #
+        # Note that the event_push_actions_staging table is ignored since it is
+        # assumed that items in that table will only exist for a short period of
+        # time.
+        if not event_push_actions_done:
+            result = await self.db_pool.runInteraction(
+                "event_push_backfill_thread_id",
+                add_thread_id_txn,
+                "event_push_actions",
+                progress.get("max_event_push_actions_stream_ordering", 0),
+            )
+        else:
+            result = await self.db_pool.runInteraction(
+                "event_push_backfill_thread_id",
+                add_thread_id_txn,
+                "event_push_summary",
+                progress.get("max_event_push_summary_stream_ordering", 0),
+            )
+
+            # Only done after the event_push_summary table is done.
+            if not result:
+                await self.db_pool.updates._end_background_update(
+                    "event_push_backfill_thread_id"
+                )
+
+        return result
+
     @cached(tree=True, max_entries=5000)
     async def get_unread_event_push_actions_by_room_for_user(
         self,
@@ -166,7 +338,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         user_id: str,
     ) -> NotifCounts:
         """Get the notification count, the highlight count and the unread message count
-        for a given user in a given room after the given read receipt.
+        for a given user in a given room after their latest read receipt.
 
         Note that this function assumes the user to be a current member of the room,
         since it's either called by the sync handler to handle joined room entries, or by
@@ -177,9 +349,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             user_id: The user to retrieve the counts for.
 
         Returns
-            A dict containing the counts mentioned earlier in this docstring,
-            respectively under the keys "notify_count", "highlight_count" and
-            "unread_count".
+            A NotifCounts object containing the notification count, the highlight count
+            and the unread message count.
         """
         return await self.db_pool.runInteraction(
             "get_unread_event_push_actions_by_room",
@@ -194,20 +365,22 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         room_id: str,
         user_id: str,
     ) -> NotifCounts:
+        # Get the stream ordering of the user's latest receipt in the room.
         result = self.get_last_receipt_for_user_txn(
             txn,
             user_id,
             room_id,
-            receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
+            receipt_types=(
+                ReceiptTypes.READ,
+                ReceiptTypes.READ_PRIVATE,
+            ),
         )
 
-        stream_ordering = None
         if result:
             _, stream_ordering = result
 
-        if stream_ordering is None:
-            # Either last_read_event_id is None, or it's an event we don't have (e.g.
-            # because it's been purged), in which case retrieve the stream ordering for
+        else:
+            # If the user has no receipts in the room, retrieve the stream ordering for
             # the latest membership event from this user in this room (which we assume is
             # a join).
             event_id = self.db_pool.simple_select_one_onecol_txn(
@@ -224,10 +397,26 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         )
 
     def _get_unread_counts_by_pos_txn(
-        self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        user_id: str,
+        receipt_stream_ordering: int,
     ) -> NotifCounts:
         """Get the number of unread messages for a user/room that have happened
         since the given stream ordering.
+
+        Args:
+            txn: The database transaction.
+            room_id: The room ID to get unread counts for.
+            user_id: The user ID to get unread counts for.
+            receipt_stream_ordering: The stream ordering of the user's latest
+                receipt in the room. If there are no receipts, the stream ordering
+                of the user's join event.
+
+        Returns
+            A NotifCounts object containing the notification count, the highlight count
+            and the unread message count.
         """
 
         counts = NotifCounts()
@@ -255,7 +444,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                     OR last_receipt_stream_ordering = ?
                 )
             """,
-            (room_id, user_id, stream_ordering, stream_ordering),
+            (room_id, user_id, receipt_stream_ordering, receipt_stream_ordering),
         )
         row = txn.fetchone()
 
@@ -265,7 +454,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             counts.notify_count += row[1]
             counts.unread_count += row[2]
 
-        # Next we need to count highlights, which aren't summarized
+        # Next we need to count highlights, which aren't summarised
         sql = """
             SELECT COUNT(*) FROM event_push_actions
             WHERE user_id = ?
@@ -273,17 +462,20 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                 AND stream_ordering > ?
                 AND highlight = 1
         """
-        txn.execute(sql, (user_id, room_id, stream_ordering))
+        txn.execute(sql, (user_id, room_id, receipt_stream_ordering))
         row = txn.fetchone()
         if row:
             counts.highlight_count += row[0]
 
         # Finally we need to count push actions that aren't included in the
-        # summary returned above, e.g. recent events that haven't been
-        # summarized yet, or the summary is empty due to a recent read receipt.
-        stream_ordering = max(stream_ordering, summary_stream_ordering)
+        # summary returned above. This might be due to recent events that haven't
+        # been summarised yet or the summary is out of date due to a recent read
+        # receipt.
+        start_unread_stream_ordering = max(
+            receipt_stream_ordering, summary_stream_ordering
+        )
         notify_count, unread_count = self._get_notif_unread_count_for_user_room(
-            txn, room_id, user_id, stream_ordering
+            txn, room_id, user_id, start_unread_stream_ordering
         )
 
         counts.notify_count += notify_count
@@ -304,6 +496,17 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
         Does not consult `event_push_summary` table, which may include push
         actions that have been deleted from `event_push_actions` table.
+
+        Args:
+            txn: The database transaction.
+            room_id: The room ID to get unread counts for.
+            user_id: The user ID to get unread counts for.
+            stream_ordering: The (exclusive) minimum stream ordering to consider.
+            max_stream_ordering: The (inclusive) maximum stream ordering to consider.
+                If this is not given, then no maximum is applied.
+
+        Return:
+            A tuple of the notif count and unread count in the given range.
         """
 
         # If there have been no events in the room since the stream ordering,
@@ -354,6 +557,31 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
         return await self.db_pool.runInteraction("get_push_action_users_in_range", f)
 
+    def _get_receipts_by_room_txn(
+        self, txn: LoggingTransaction, user_id: str
+    ) -> List[Tuple[str, int]]:
+        receipt_types_clause, args = make_in_list_sql_clause(
+            self.database_engine,
+            "receipt_type",
+            (
+                ReceiptTypes.READ,
+                ReceiptTypes.READ_PRIVATE,
+            ),
+        )
+
+        sql = f"""
+            SELECT room_id, MAX(stream_ordering)
+            FROM receipts_linearized
+            INNER JOIN events USING (room_id, event_id)
+            WHERE {receipt_types_clause}
+            AND user_id = ?
+            GROUP BY room_id
+        """
+
+        args.extend((user_id,))
+        txn.execute(sql, args)
+        return cast(List[Tuple[str, int]], txn.fetchall())
+
     async def get_unread_push_actions_for_user_in_range_for_http(
         self,
         user_id: str,
@@ -376,81 +604,46 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             The list will be ordered by ascending stream_ordering.
             The list will have between 0~limit entries.
         """
-        # find rooms that have a read receipt in them and return the next
-        # push actions
-        def get_after_receipt(
-            txn: LoggingTransaction,
-        ) -> List[Tuple[str, str, int, str, bool]]:
-            # find rooms that have a read receipt in them and return the next
-            # push actions
-            sql = (
-                "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
-                "   ep.highlight "
-                " FROM ("
-                "   SELECT room_id,"
-                "       MAX(stream_ordering) as stream_ordering"
-                "   FROM events"
-                "   INNER JOIN receipts_linearized USING (room_id, event_id)"
-                "   WHERE receipt_type = 'm.read' AND user_id = ?"
-                "   GROUP BY room_id"
-                ") AS rl,"
-                " event_push_actions AS ep"
-                " WHERE"
-                "   ep.room_id = rl.room_id"
-                "   AND ep.stream_ordering > rl.stream_ordering"
-                "   AND ep.user_id = ?"
-                "   AND ep.stream_ordering > ?"
-                "   AND ep.stream_ordering <= ?"
-                "   AND ep.notif = 1"
-                " ORDER BY ep.stream_ordering ASC LIMIT ?"
-            )
-            args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
-            txn.execute(sql, args)
-            return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
 
-        after_read_receipt = await self.db_pool.runInteraction(
-            "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
+        receipts_by_room = dict(
+            await self.db_pool.runInteraction(
+                "get_unread_push_actions_for_user_in_range_http_receipts",
+                self._get_receipts_by_room_txn,
+                user_id=user_id,
+            ),
         )
 
-        # There are rooms with push actions in them but you don't have a read receipt in
-        # them e.g. rooms you've been invited to, so get push actions for rooms which do
-        # not have read receipts in them too.
-        def get_no_receipt(
+        def get_push_actions_txn(
             txn: LoggingTransaction,
         ) -> List[Tuple[str, str, int, str, bool]]:
-            sql = (
-                "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
-                "   ep.highlight "
-                " FROM event_push_actions AS ep"
-                " INNER JOIN events AS e USING (room_id, event_id)"
-                " WHERE"
-                "   ep.room_id NOT IN ("
-                "     SELECT room_id FROM receipts_linearized"
-                "       WHERE receipt_type = 'm.read' AND user_id = ?"
-                "       GROUP BY room_id"
-                "   )"
-                "   AND ep.user_id = ?"
-                "   AND ep.stream_ordering > ?"
-                "   AND ep.stream_ordering <= ?"
-                "   AND ep.notif = 1"
-                " ORDER BY ep.stream_ordering ASC LIMIT ?"
-            )
-            args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
-            txn.execute(sql, args)
+            sql = """
+                SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, ep.highlight
+                FROM event_push_actions AS ep
+                WHERE
+                    ep.user_id = ?
+                    AND ep.stream_ordering > ?
+                    AND ep.stream_ordering <= ?
+                    AND ep.notif = 1
+                ORDER BY ep.stream_ordering ASC LIMIT ?
+            """
+            txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit))
             return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
 
-        no_read_receipt = await self.db_pool.runInteraction(
-            "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
+        push_actions = await self.db_pool.runInteraction(
+            "get_unread_push_actions_for_user_in_range_http", get_push_actions_txn
         )
 
         notifs = [
             HttpPushAction(
-                event_id=row[0],
-                room_id=row[1],
-                stream_ordering=row[2],
-                actions=_deserialize_action(row[3], row[4]),
+                event_id=event_id,
+                room_id=room_id,
+                stream_ordering=stream_ordering,
+                actions=_deserialize_action(actions, highlight),
             )
-            for row in after_read_receipt + no_read_receipt
+            for event_id, room_id, stream_ordering, actions, highlight in push_actions
+            # Only include push actions with a stream ordering after any receipt, or without any
+            # receipt present (invited to but never read rooms).
+            if stream_ordering > receipts_by_room.get(room_id, 0)
         ]
 
         # Now sort it so it's ordered correctly, since currently it will
@@ -485,82 +678,50 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             The list will be ordered by descending received_ts.
             The list will have between 0~limit entries.
         """
-        # find rooms that have a read receipt in them and return the most recent
-        # push actions
-        def get_after_receipt(
-            txn: LoggingTransaction,
-        ) -> List[Tuple[str, str, int, str, bool, int]]:
-            sql = (
-                "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
-                "  ep.highlight, e.received_ts"
-                " FROM ("
-                "   SELECT room_id,"
-                "       MAX(stream_ordering) as stream_ordering"
-                "   FROM events"
-                "   INNER JOIN receipts_linearized USING (room_id, event_id)"
-                "   WHERE receipt_type = 'm.read' AND user_id = ?"
-                "   GROUP BY room_id"
-                ") AS rl,"
-                " event_push_actions AS ep"
-                " INNER JOIN events AS e USING (room_id, event_id)"
-                " WHERE"
-                "   ep.room_id = rl.room_id"
-                "   AND ep.stream_ordering > rl.stream_ordering"
-                "   AND ep.user_id = ?"
-                "   AND ep.stream_ordering > ?"
-                "   AND ep.stream_ordering <= ?"
-                "   AND ep.notif = 1"
-                " ORDER BY ep.stream_ordering DESC LIMIT ?"
-            )
-            args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
-            txn.execute(sql, args)
-            return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
 
-        after_read_receipt = await self.db_pool.runInteraction(
-            "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
+        receipts_by_room = dict(
+            await self.db_pool.runInteraction(
+                "get_unread_push_actions_for_user_in_range_email_receipts",
+                self._get_receipts_by_room_txn,
+                user_id=user_id,
+            ),
         )
 
-        # There are rooms with push actions in them but you don't have a read receipt in
-        # them e.g. rooms you've been invited to, so get push actions for rooms which do
-        # not have read receipts in them too.
-        def get_no_receipt(
+        def get_push_actions_txn(
             txn: LoggingTransaction,
         ) -> List[Tuple[str, str, int, str, bool, int]]:
-            sql = (
-                "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
-                "   ep.highlight, e.received_ts"
-                " FROM event_push_actions AS ep"
-                " INNER JOIN events AS e USING (room_id, event_id)"
-                " WHERE"
-                "   ep.room_id NOT IN ("
-                "     SELECT room_id FROM receipts_linearized"
-                "       WHERE receipt_type = 'm.read' AND user_id = ?"
-                "       GROUP BY room_id"
-                "   )"
-                "   AND ep.user_id = ?"
-                "   AND ep.stream_ordering > ?"
-                "   AND ep.stream_ordering <= ?"
-                "   AND ep.notif = 1"
-                " ORDER BY ep.stream_ordering DESC LIMIT ?"
-            )
-            args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
-            txn.execute(sql, args)
+            sql = """
+                SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
+                    ep.highlight, e.received_ts
+                FROM event_push_actions AS ep
+                INNER JOIN events AS e USING (room_id, event_id)
+                WHERE
+                    ep.user_id = ?
+                    AND ep.stream_ordering > ?
+                    AND ep.stream_ordering <= ?
+                    AND ep.notif = 1
+                ORDER BY ep.stream_ordering DESC LIMIT ?
+            """
+            txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit))
             return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
 
-        no_read_receipt = await self.db_pool.runInteraction(
-            "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
+        push_actions = await self.db_pool.runInteraction(
+            "get_unread_push_actions_for_user_in_range_email", get_push_actions_txn
         )
 
         # Make a list of dicts from the two sets of results.
         notifs = [
             EmailPushAction(
-                event_id=row[0],
-                room_id=row[1],
-                stream_ordering=row[2],
-                actions=_deserialize_action(row[3], row[4]),
-                received_ts=row[5],
+                event_id=event_id,
+                room_id=room_id,
+                stream_ordering=stream_ordering,
+                actions=_deserialize_action(actions, highlight),
+                received_ts=received_ts,
             )
-            for row in after_read_receipt + no_read_receipt
+            for event_id, room_id, stream_ordering, actions, highlight, received_ts in push_actions
+            # Only include push actions with a stream ordering after any receipt, or without any
+            # receipt present (invited to but never read rooms).
+            if stream_ordering > receipts_by_room.get(room_id, 0)
         ]
 
         # Now sort it so it's ordered correctly, since currently it will
@@ -606,8 +767,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
     async def add_push_actions_to_staging(
         self,
         event_id: str,
-        user_id_actions: Dict[str, List[Union[dict, str]]],
+        user_id_actions: Dict[str, Collection[Union[Mapping, str]]],
         count_as_unread: bool,
+        thread_id: str,
     ) -> None:
         """Add the push actions for the event to the push action staging area.
 
@@ -616,6 +778,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             user_id_actions: A mapping of user_id to list of push actions, where
                 an action can either be a string or dict.
             count_as_unread: Whether this event should increment unread counts.
+            thread_id: The thread this event is parent of, if applicable.
         """
         if not user_id_actions:
             return
@@ -623,8 +786,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         # This is a helper function for generating the necessary tuple that
         # can be used to insert into the `event_push_actions_staging` table.
         def _gen_entry(
-            user_id: str, actions: List[Union[dict, str]]
-        ) -> Tuple[str, str, str, int, int, int]:
+            user_id: str, actions: Collection[Union[Mapping, str]]
+        ) -> Tuple[str, str, str, int, int, int, str]:
             is_highlight = 1 if _action_has_highlight(actions) else 0
             notif = 1 if "notify" in actions else 0
             return (
@@ -634,28 +797,25 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                 notif,  # notif column
                 is_highlight,  # highlight column
                 int(count_as_unread),  # unread column
+                thread_id,  # thread_id column
             )
 
-        def _add_push_actions_to_staging_txn(txn: LoggingTransaction) -> None:
-            # We don't use simple_insert_many here to avoid the overhead
-            # of generating lists of dicts.
-
-            sql = """
-                INSERT INTO event_push_actions_staging
-                    (event_id, user_id, actions, notif, highlight, unread)
-                VALUES (?, ?, ?, ?, ?, ?)
-            """
-
-            txn.execute_batch(
-                sql,
-                (
-                    _gen_entry(user_id, actions)
-                    for user_id, actions in user_id_actions.items()
-                ),
-            )
-
-        return await self.db_pool.runInteraction(
-            "add_push_actions_to_staging", _add_push_actions_to_staging_txn
+        await self.db_pool.simple_insert_many(
+            "event_push_actions_staging",
+            keys=(
+                "event_id",
+                "user_id",
+                "actions",
+                "notif",
+                "highlight",
+                "unread",
+                "thread_id",
+            ),
+            values=[
+                _gen_entry(user_id, actions)
+                for user_id, actions in user_id_actions.items()
+            ],
+            desc="add_push_actions_to_staging",
         )
 
     async def remove_push_actions_from_staging(self, event_id: str) -> None:
@@ -769,12 +929,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         # [10, <none>, 20], we should treat this as being equivalent to
         # [10, 10, 20].
         #
-        sql = (
-            "SELECT received_ts FROM events"
-            " WHERE stream_ordering <= ?"
-            " ORDER BY stream_ordering DESC"
-            " LIMIT 1"
-        )
+        sql = """
+            SELECT received_ts FROM events
+            WHERE stream_ordering <= ?
+            ORDER BY stream_ordering DESC
+            LIMIT 1
+        """
 
         while range_end - range_start > 0:
             middle = (range_end + range_start) // 2
@@ -802,14 +962,14 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         self, stream_ordering: int
     ) -> Optional[int]:
         def f(txn: LoggingTransaction) -> Optional[Tuple[int]]:
-            sql = (
-                "SELECT e.received_ts"
-                " FROM event_push_actions AS ep"
-                " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
-                " WHERE ep.stream_ordering > ? AND notif = 1"
-                " ORDER BY ep.stream_ordering ASC"
-                " LIMIT 1"
-            )
+            sql = """
+                SELECT e.received_ts
+                FROM event_push_actions AS ep
+                JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id
+                WHERE ep.stream_ordering > ? AND notif = 1
+                ORDER BY ep.stream_ordering ASC
+                LIMIT 1
+            """
             txn.execute(sql, (stream_ordering,))
             return cast(Optional[Tuple[int]], txn.fetchone())
 
@@ -858,10 +1018,13 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         Any push actions which predate the user's most recent read receipt are
         now redundant, so we can remove them from `event_push_actions` and
         update `event_push_summary`.
+
+        Returns true if all new receipts have been processed.
         """
 
         limit = 100
 
+        # The (inclusive) receipt stream ID that was previously processed..
         min_receipts_stream_id = self.db_pool.simple_select_one_onecol_txn(
             txn,
             table="event_push_summary_last_receipt_stream_id",
@@ -871,6 +1034,14 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
         max_receipts_stream_id = self._receipts_id_gen.get_current_token()
 
+        # The (inclusive) event stream ordering that was previously summarised.
+        old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
+            txn,
+            table="event_push_summary_stream_ordering",
+            keyvalues={},
+            retcol="stream_ordering",
+        )
+
         sql = """
             SELECT r.stream_id, r.room_id, r.user_id, e.stream_ordering
             FROM receipts_linearized AS r
@@ -895,13 +1066,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         )
         rows = txn.fetchall()
 
-        old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
-            txn,
-            table="event_push_summary_stream_ordering",
-            keyvalues={},
-            retcol="stream_ordering",
-        )
-
         # For each new read receipt we delete push actions from before it and
         # recalculate the summary.
         for _, room_id, user_id, stream_ordering in rows:
@@ -920,10 +1084,15 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                 (room_id, user_id, stream_ordering),
             )
 
+            # Fetch the notification counts between the stream ordering of the
+            # latest receipt and what was previously summarised.
             notif_count, unread_count = self._get_notif_unread_count_for_user_room(
                 txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering
             )
 
+            # Replace the previous summary with the new counts.
+            #
+            # TODO(threads): Upsert per-thread instead of setting them all to main.
             self.db_pool.simple_upsert_txn(
                 txn,
                 table="event_push_summary",
@@ -933,6 +1102,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                     "unread_count": unread_count,
                     "stream_ordering": old_rotate_stream_ordering,
                     "last_receipt_stream_ordering": stream_ordering,
+                    "thread_id": "main",
                 },
             )
 
@@ -956,10 +1126,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         return len(rows) < limit
 
     def _rotate_notifs_txn(self, txn: LoggingTransaction) -> bool:
-        """Archives older notifications into event_push_summary. Returns whether
-        the archiving process has caught up or not.
+        """Archives older notifications (from event_push_actions) into event_push_summary.
+
+        Returns whether the archiving process has caught up or not.
         """
 
+        # The (inclusive) event stream ordering that was previously summarised.
         old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
             txn,
             table="event_push_summary_stream_ordering",
@@ -974,7 +1146,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             SELECT stream_ordering FROM event_push_actions
             WHERE stream_ordering > ?
             ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?
-        """,
+            """,
             (old_rotate_stream_ordering, self._rotate_count),
         )
         stream_row = txn.fetchone()
@@ -993,19 +1165,29 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
         logger.info("Rotating notifications up to: %s", rotate_to_stream_ordering)
 
-        self._rotate_notifs_before_txn(txn, rotate_to_stream_ordering)
+        self._rotate_notifs_before_txn(
+            txn, old_rotate_stream_ordering, rotate_to_stream_ordering
+        )
 
         return caught_up
 
     def _rotate_notifs_before_txn(
-        self, txn: LoggingTransaction, rotate_to_stream_ordering: int
+        self,
+        txn: LoggingTransaction,
+        old_rotate_stream_ordering: int,
+        rotate_to_stream_ordering: int,
     ) -> None:
-        old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
-            txn,
-            table="event_push_summary_stream_ordering",
-            keyvalues={},
-            retcol="stream_ordering",
-        )
+        """Archives older notifications (from event_push_actions) into event_push_summary.
+
+        Any event_push_actions between old_rotate_stream_ordering (exclusive) and
+        rotate_to_stream_ordering (inclusive) will be added to the event_push_summary
+        table.
+
+        Args:
+            txn: The database transaction.
+            old_rotate_stream_ordering: The previous maximum event stream ordering.
+            rotate_to_stream_ordering: The new maximum event stream ordering to summarise.
+        """
 
         # Calculate the new counts that should be upserted into event_push_summary
         sql = """
@@ -1069,17 +1251,19 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
         logger.info("Rotating notifications, handling %d rows", len(summaries))
 
+        # TODO(threads): Update on a per-thread basis.
         self.db_pool.simple_upsert_many_txn(
             txn,
             table="event_push_summary",
             key_names=("user_id", "room_id"),
             key_values=[(user_id, room_id) for user_id, room_id in summaries],
-            value_names=("notif_count", "unread_count", "stream_ordering"),
+            value_names=("notif_count", "unread_count", "stream_ordering", "thread_id"),
             value_values=[
                 (
                     summary.notif_count,
                     summary.unread_count,
                     summary.stream_ordering,
+                    "main",
                 )
                 for summary in summaries.values()
             ],
@@ -1090,12 +1274,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             (rotate_to_stream_ordering,),
         )
 
-    async def _remove_old_push_actions_that_have_rotated(
-        self,
-    ) -> None:
-        """Clear out old push actions that have been summarized."""
+    async def _remove_old_push_actions_that_have_rotated(self) -> None:
+        """Clear out old push actions that have been summarised."""
 
-        # We want to clear out anything that older than a day that *has* already
+        # We want to clear out anything that is older than a day that *has* already
         # been rotated.
         rotated_upto_stream_ordering = await self.db_pool.simple_select_one_onecol(
             table="event_push_summary_stream_ordering",
@@ -1119,7 +1301,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                 SELECT stream_ordering FROM event_push_actions
                 WHERE stream_ordering <= ? AND highlight = 0
                 ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?
-            """,
+                """,
                 (
                     max_stream_ordering_to_delete,
                     batch_size,
@@ -1188,7 +1370,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             table="event_push_actions",
             columns=["highlight", "stream_ordering"],
             where_clause="highlight=0",
-            psql_only=True,
         )
 
     async def get_push_actions_for_user(
@@ -1215,16 +1396,18 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
 
             # NB. This assumes event_ids are globally unique since
             # it makes the query easier to index
-            sql = (
-                "SELECT epa.event_id, epa.room_id,"
-                " epa.stream_ordering, epa.topological_ordering,"
-                " epa.actions, epa.highlight, epa.profile_tag, e.received_ts"
-                " FROM event_push_actions epa, events e"
-                " WHERE epa.event_id = e.event_id"
-                " AND epa.user_id = ? %s"
-                " AND epa.notif = 1"
-                " ORDER BY epa.stream_ordering DESC"
-                " LIMIT ?" % (before_clause,)
+            sql = """
+                SELECT epa.event_id, epa.room_id,
+                    epa.stream_ordering, epa.topological_ordering,
+                    epa.actions, epa.highlight, epa.profile_tag, e.received_ts
+                FROM event_push_actions epa, events e
+                WHERE epa.event_id = e.event_id
+                    AND epa.user_id = ? %s
+                    AND epa.notif = 1
+                ORDER BY epa.stream_ordering DESC
+                LIMIT ?
+            """ % (
+                before_clause,
             )
             txn.execute(sql, args)
             return cast(
@@ -1247,7 +1430,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
         ]
 
 
-def _action_has_highlight(actions: List[Union[dict, str]]) -> bool:
+def _action_has_highlight(actions: Collection[Union[Mapping, str]]) -> bool:
     for action in actions:
         if not isinstance(action, dict):
             continue
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index eb4efbb93c..1b54a2eb57 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -40,6 +40,7 @@ from synapse.api.errors import Codes, SynapseError
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase, relation_from_event
 from synapse.events.snapshot import EventContext
+from synapse.logging.opentracing import trace
 from synapse.storage._base import db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
@@ -145,6 +146,7 @@ class PersistEventsStore:
         self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
         self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
 
+    @trace
     async def _persist_events_and_state_updates(
         self,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
@@ -1293,7 +1295,7 @@ class PersistEventsStore:
         depth_updates: Dict[str, int] = {}
         for event, context in events_and_contexts:
             # Remove the any existing cache entries for the event_ids
-            txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
+            self.store.invalidate_get_event_cache_after_txn(txn, event.event_id)
             # Then update the `stream_ordering` position to mark the latest
             # event as the front of the room. This should not be done for
             # backfilled events because backfilled events have negative
@@ -1346,9 +1348,24 @@ class PersistEventsStore:
             event_id: outlier for event_id, outlier in txn
         }
 
+        logger.debug(
+            "_update_outliers_txn: events=%s have_persisted=%s",
+            [ev.event_id for ev, _ in events_and_contexts],
+            have_persisted,
+        )
+
         to_remove = set()
         for event, context in events_and_contexts:
-            if event.event_id not in have_persisted:
+            outlier_persisted = have_persisted.get(event.event_id)
+            logger.debug(
+                "_update_outliers_txn: event=%s outlier=%s outlier_persisted=%s",
+                event.event_id,
+                event.internal_metadata.is_outlier(),
+                outlier_persisted,
+            )
+
+            # Ignore events which we haven't persisted at all
+            if outlier_persisted is None:
                 continue
 
             to_remove.add(event)
@@ -1358,7 +1375,6 @@ class PersistEventsStore:
                 # was an outlier or not - what we have is at least as good.
                 continue
 
-            outlier_persisted = have_persisted[event.event_id]
             if not event.internal_metadata.is_outlier() and outlier_persisted:
                 # We received a copy of an event that we had already stored as
                 # an outlier in the database. We now have some state at that event
@@ -1369,7 +1385,10 @@ class PersistEventsStore:
                 # events down /sync. In general they will be historical events, so that
                 # doesn't matter too much, but that is not always the case.
 
-                logger.info("Updating state for ex-outlier event %s", event.event_id)
+                logger.info(
+                    "_update_outliers_txn: Updating state for ex-outlier event %s",
+                    event.event_id,
+                )
 
                 # insert into event_to_state_groups.
                 try:
@@ -1473,7 +1492,7 @@ class PersistEventsStore:
                     event.sender,
                     "url" in event.content and isinstance(event.content["url"], str),
                     event.get_state_key(),
-                    context.rejected or None,
+                    context.rejected,
                 )
                 for event, context in events_and_contexts
             ),
@@ -1669,13 +1688,13 @@ class PersistEventsStore:
             if not row["rejects"] and not row["redacts"]:
                 to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
 
-        def prefill() -> None:
+        async def prefill() -> None:
             for cache_entry in to_prefill:
-                self.store._get_event_cache.set(
+                await self.store._get_event_cache.set(
                     (cache_entry.event.event_id,), cache_entry
                 )
 
-        txn.call_after(prefill)
+        txn.async_call_after(prefill)
 
     def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None:
         """Invalidate the caches for the redacted event.
@@ -1684,7 +1703,7 @@ class PersistEventsStore:
         _invalidate_caches_for_event.
         """
         assert event.redacts is not None
-        txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
+        self.store.invalidate_get_event_cache_after_txn(txn, event.redacts)
         txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,))
         txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,))
 
@@ -2173,9 +2192,9 @@ class PersistEventsStore:
         sql = """
             INSERT INTO event_push_actions (
                 room_id, event_id, user_id, actions, stream_ordering,
-                topological_ordering, notif, highlight, unread
+                topological_ordering, notif, highlight, unread, thread_id
             )
-            SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread
+            SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread, thread_id
             FROM event_push_actions_staging
             WHERE event_id = ?
         """
@@ -2416,17 +2435,31 @@ class PersistEventsStore:
             "DELETE FROM event_backward_extremities"
             " WHERE event_id = ? AND room_id = ?"
         )
+        backward_extremity_tuples_to_remove = [
+            (ev.event_id, ev.room_id)
+            for ev in events
+            if not ev.internal_metadata.is_outlier()
+            # If we encountered an event with no prev_events, then we might
+            # as well remove it now because it won't ever have anything else
+            # to backfill from.
+            or len(ev.prev_event_ids()) == 0
+        ]
         txn.execute_batch(
             query,
-            [
-                (ev.event_id, ev.room_id)
-                for ev in events
-                if not ev.internal_metadata.is_outlier()
-                # If we encountered an event with no prev_events, then we might
-                # as well remove it now because it won't ever have anything else
-                # to backfill from.
-                or len(ev.prev_event_ids()) == 0
-            ],
+            backward_extremity_tuples_to_remove,
+        )
+
+        # Clear out the failed backfill attempts after we successfully pulled
+        # the event. Since we no longer need these events as backward
+        # extremities, it also means that they won't be backfilled from again so
+        # we no longer need to store the backfill attempts around it.
+        query = """
+            DELETE FROM event_failed_pull_attempts
+            WHERE event_id = ? and room_id = ?
+        """
+        txn.execute_batch(
+            query,
+            backward_extremity_tuples_to_remove,
         )
 
 
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index eeca85fc94..6e8aeed7b4 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -67,6 +67,8 @@ class _BackgroundUpdates:
     EVENT_EDGES_DROP_INVALID_ROWS = "event_edges_drop_invalid_rows"
     EVENT_EDGES_REPLACE_INDEX = "event_edges_replace_index"
 
+    EVENTS_POPULATE_STATE_KEY_REJECTIONS = "events_populate_state_key_rejections"
+
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class _CalculateChainCover:
@@ -253,6 +255,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             replaces_index="ev_edges_id",
         )
 
+        self.db_pool.updates.register_background_update_handler(
+            _BackgroundUpdates.EVENTS_POPULATE_STATE_KEY_REJECTIONS,
+            self._background_events_populate_state_key_rejections,
+        )
+
     async def _background_reindex_fields_sender(
         self, progress: JsonDict, batch_size: int
     ) -> int:
@@ -1399,3 +1406,83 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             )
 
         return batch_size
+
+    async def _background_events_populate_state_key_rejections(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        """Back-populate `events.state_key` and `events.rejection_reason"""
+
+        min_stream_ordering_exclusive = progress["min_stream_ordering_exclusive"]
+        max_stream_ordering_inclusive = progress["max_stream_ordering_inclusive"]
+
+        def _populate_txn(txn: LoggingTransaction) -> bool:
+            """Returns True if we're done."""
+
+            # first we need to find an endpoint.
+            # we need to find the final row in the batch of batch_size, which means
+            # we need to skip over (batch_size-1) rows and get the next row.
+            txn.execute(
+                """
+                SELECT stream_ordering FROM events
+                WHERE stream_ordering > ? AND stream_ordering <= ?
+                ORDER BY stream_ordering
+                LIMIT 1 OFFSET ?
+                """,
+                (
+                    min_stream_ordering_exclusive,
+                    max_stream_ordering_inclusive,
+                    batch_size - 1,
+                ),
+            )
+
+            endpoint = None
+            row = txn.fetchone()
+            if row:
+                endpoint = row[0]
+
+            where_clause = "stream_ordering > ?"
+            args = [min_stream_ordering_exclusive]
+            if endpoint:
+                where_clause += " AND stream_ordering <= ?"
+                args.append(endpoint)
+
+            # now do the updates.
+            txn.execute(
+                f"""
+                UPDATE events
+                SET state_key = (SELECT state_key FROM state_events se WHERE se.event_id = events.event_id),
+                    rejection_reason = (SELECT reason FROM rejections rej WHERE rej.event_id = events.event_id)
+                WHERE ({where_clause})
+                """,
+                args,
+            )
+
+            logger.info(
+                "populated new `events` columns up to %s/%i: updated %i rows",
+                endpoint,
+                max_stream_ordering_inclusive,
+                txn.rowcount,
+            )
+
+            if endpoint is None:
+                # we're done
+                return True
+
+            progress["min_stream_ordering_exclusive"] = endpoint
+            self.db_pool.updates._background_update_progress_txn(
+                txn,
+                _BackgroundUpdates.EVENTS_POPULATE_STATE_KEY_REJECTIONS,
+                progress,
+            )
+            return False
+
+        done = await self.db_pool.runInteraction(
+            desc="events_populate_state_key_rejections", func=_populate_txn
+        )
+
+        if done:
+            await self.db_pool.updates._end_background_update(
+                _BackgroundUpdates.EVENTS_POPULATE_STATE_KEY_REJECTIONS
+            )
+
+        return batch_size
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 4d916b240e..4f1dd6d6b6 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -54,6 +54,7 @@ from synapse.logging.context import (
     current_context,
     make_deferred_yieldable,
 )
+from synapse.logging.opentracing import start_active_span, tag_args, trace
 from synapse.metrics.background_process_metrics import (
     run_as_background_process,
     wrap_as_background_process,
@@ -79,7 +80,8 @@ from synapse.types import JsonDict, get_domain_from_id
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
 from synapse.util.caches.descriptors import cached, cachedList
-from synapse.util.caches.lrucache import LruCache
+from synapse.util.caches.lrucache import AsyncLruCache
+from synapse.util.cancellation import cancellable
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
@@ -238,7 +240,9 @@ class EventsWorkerStore(SQLBaseStore):
                 5 * 60 * 1000,
             )
 
-        self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache(
+        self._get_event_cache: AsyncLruCache[
+            Tuple[str], EventCacheEntry
+        ] = AsyncLruCache(
             cache_name="*getEvent*",
             max_size=hs.config.caches.event_cache_size,
         )
@@ -292,25 +296,6 @@ class EventsWorkerStore(SQLBaseStore):
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
 
-    async def get_received_ts(self, event_id: str) -> Optional[int]:
-        """Get received_ts (when it was persisted) for the event.
-
-        Raises an exception for unknown events.
-
-        Args:
-            event_id: The event ID to query.
-
-        Returns:
-            Timestamp in milliseconds, or None for events that were persisted
-            before received_ts was implemented.
-        """
-        return await self.db_pool.simple_select_one_onecol(
-            table="events",
-            keyvalues={"event_id": event_id},
-            retcol="received_ts",
-            desc="get_received_ts",
-        )
-
     async def have_censored_event(self, event_id: str) -> bool:
         """Check if an event has been censored, i.e. if the content of the event has been erased
         from the database due to a redaction.
@@ -355,6 +340,7 @@ class EventsWorkerStore(SQLBaseStore):
     ) -> Optional[EventBase]:
         ...
 
+    @cancellable
     async def get_event(
         self,
         event_id: str,
@@ -447,6 +433,9 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {e.event_id: e for e in events}
 
+    @trace
+    @tag_args
+    @cancellable
     async def get_events_as_list(
         self,
         event_ids: Collection[str],
@@ -598,6 +587,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         return events
 
+    @cancellable
     async def _get_events_from_cache_or_db(
         self, event_ids: Iterable[str], allow_rejected: bool = False
     ) -> Dict[str, EventCacheEntry]:
@@ -617,7 +607,11 @@ class EventsWorkerStore(SQLBaseStore):
         Returns:
             map from event id to result
         """
-        event_entry_map = self._get_events_from_cache(
+        # Shortcut: check if we have any events in the *in memory* cache - this function
+        # may be called repeatedly for the same event so at this point we cannot reach
+        # out to any external cache for performance reasons. The external cache is
+        # checked later on in the `get_missing_events_from_cache_or_db` function below.
+        event_entry_map = self._get_events_from_local_cache(
             event_ids,
         )
 
@@ -649,7 +643,9 @@ class EventsWorkerStore(SQLBaseStore):
 
         if missing_events_ids:
 
-            async def get_missing_events_from_db() -> Dict[str, EventCacheEntry]:
+            async def get_missing_events_from_cache_or_db() -> Dict[
+                str, EventCacheEntry
+            ]:
                 """Fetches the events in `missing_event_ids` from the database.
 
                 Also creates entries in `self._current_event_fetches` to allow
@@ -674,10 +670,18 @@ class EventsWorkerStore(SQLBaseStore):
                 # the events have been redacted, and if so pulling the redaction event
                 # out of the database to check it.
                 #
+                missing_events = {}
                 try:
-                    missing_events = await self._get_events_from_db(
+                    # Try to fetch from any external cache. We already checked the
+                    # in-memory cache above.
+                    missing_events = await self._get_events_from_external_cache(
                         missing_events_ids,
                     )
+                    # Now actually fetch any remaining events from the DB
+                    db_missing_events = await self._get_events_from_db(
+                        missing_events_ids - missing_events.keys(),
+                    )
+                    missing_events.update(db_missing_events)
                 except Exception as e:
                     with PreserveLoggingContext():
                         fetching_deferred.errback(e)
@@ -696,7 +700,7 @@ class EventsWorkerStore(SQLBaseStore):
             # cancellations, since multiple `_get_events_from_cache_or_db` calls can
             # reuse the same fetch.
             missing_events: Dict[str, EventCacheEntry] = await delay_cancellation(
-                get_missing_events_from_db()
+                get_missing_events_from_cache_or_db()
             )
             event_entry_map.update(missing_events)
 
@@ -729,15 +733,96 @@ class EventsWorkerStore(SQLBaseStore):
 
         return event_entry_map
 
-    def _invalidate_get_event_cache(self, event_id: str) -> None:
-        self._get_event_cache.invalidate((event_id,))
+    def invalidate_get_event_cache_after_txn(
+        self, txn: LoggingTransaction, event_id: str
+    ) -> None:
+        """
+        Prepares a database transaction to invalidate the get event cache for a given
+        event ID when executed successfully. This is achieved by attaching two callbacks
+        to the transaction, one to invalidate the async cache and one for the in memory
+        sync cache (importantly called in that order).
+
+        Arguments:
+            txn: the database transaction to attach the callbacks to
+            event_id: the event ID to be invalidated from caches
+        """
+
+        txn.async_call_after(self._invalidate_async_get_event_cache, event_id)
+        txn.call_after(self._invalidate_local_get_event_cache, event_id)
+
+    async def _invalidate_async_get_event_cache(self, event_id: str) -> None:
+        """
+        Invalidates an event in the asyncronous get event cache, which may be remote.
+
+        Arguments:
+            event_id: the event ID to invalidate
+        """
+
+        await self._get_event_cache.invalidate((event_id,))
+
+    def _invalidate_local_get_event_cache(self, event_id: str) -> None:
+        """
+        Invalidates an event in local in-memory get event caches.
+
+        Arguments:
+            event_id: the event ID to invalidate
+        """
+
+        self._get_event_cache.invalidate_local((event_id,))
         self._event_ref.pop(event_id, None)
         self._current_event_fetches.pop(event_id, None)
 
-    def _get_events_from_cache(
+    async def _get_events_from_cache(
+        self, events: Iterable[str], update_metrics: bool = True
+    ) -> Dict[str, EventCacheEntry]:
+        """Fetch events from the caches, both in memory and any external.
+
+        May return rejected events.
+
+        Args:
+            events: list of event_ids to fetch
+            update_metrics: Whether to update the cache hit ratio metrics
+        """
+        event_map = self._get_events_from_local_cache(
+            events, update_metrics=update_metrics
+        )
+
+        missing_event_ids = (e for e in events if e not in event_map)
+        event_map.update(
+            await self._get_events_from_external_cache(
+                events=missing_event_ids,
+                update_metrics=update_metrics,
+            )
+        )
+
+        return event_map
+
+    async def _get_events_from_external_cache(
+        self, events: Iterable[str], update_metrics: bool = True
+    ) -> Dict[str, EventCacheEntry]:
+        """Fetch events from any configured external cache.
+
+        May return rejected events.
+
+        Args:
+            events: list of event_ids to fetch
+            update_metrics: Whether to update the cache hit ratio metrics
+        """
+        event_map = {}
+
+        for event_id in events:
+            ret = await self._get_event_cache.get_external(
+                (event_id,), None, update_metrics=update_metrics
+            )
+            if ret:
+                event_map[event_id] = ret
+
+        return event_map
+
+    def _get_events_from_local_cache(
         self, events: Iterable[str], update_metrics: bool = True
     ) -> Dict[str, EventCacheEntry]:
-        """Fetch events from the caches.
+        """Fetch events from the local, in memory, caches.
 
         May return rejected events.
 
@@ -749,7 +834,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         for event_id in events:
             # First check if it's in the event cache
-            ret = self._get_event_cache.get(
+            ret = self._get_event_cache.get_local(
                 (event_id,), None, update_metrics=update_metrics
             )
             if ret:
@@ -771,7 +856,7 @@ class EventsWorkerStore(SQLBaseStore):
 
                 # We add the entry back into the cache as we want to keep
                 # recently queried events in the cache.
-                self._get_event_cache.set((event_id,), cache_entry)
+                self._get_event_cache.set_local((event_id,), cache_entry)
 
         return event_map
 
@@ -965,7 +1050,13 @@ class EventsWorkerStore(SQLBaseStore):
                 }
 
                 row_dict = self.db_pool.new_transaction(
-                    conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
+                    conn,
+                    "do_fetch",
+                    [],
+                    [],
+                    [],
+                    self._fetch_event_rows,
+                    events_to_fetch,
                 )
 
                 # We only want to resolve deferreds from the main thread
@@ -1006,23 +1097,42 @@ class EventsWorkerStore(SQLBaseStore):
         """
         fetched_event_ids: Set[str] = set()
         fetched_events: Dict[str, _EventRow] = {}
-        events_to_fetch = event_ids
 
-        while events_to_fetch:
-            row_map = await self._enqueue_events(events_to_fetch)
+        async def _fetch_event_ids_and_get_outstanding_redactions(
+            event_ids_to_fetch: Collection[str],
+        ) -> Collection[str]:
+            """
+            Fetch all of the given event_ids and return any associated redaction event_ids
+            that we still need to fetch in the next iteration.
+            """
+            row_map = await self._enqueue_events(event_ids_to_fetch)
 
             # we need to recursively fetch any redactions of those events
             redaction_ids: Set[str] = set()
-            for event_id in events_to_fetch:
+            for event_id in event_ids_to_fetch:
                 row = row_map.get(event_id)
                 fetched_event_ids.add(event_id)
                 if row:
                     fetched_events[event_id] = row
                     redaction_ids.update(row.redactions)
 
-            events_to_fetch = redaction_ids.difference(fetched_event_ids)
-            if events_to_fetch:
-                logger.debug("Also fetching redaction events %s", events_to_fetch)
+            event_ids_to_fetch = redaction_ids.difference(fetched_event_ids)
+            return event_ids_to_fetch
+
+        # Grab the initial list of events requested
+        event_ids_to_fetch = await _fetch_event_ids_and_get_outstanding_redactions(
+            event_ids
+        )
+        # Then go and recursively find all of the associated redactions
+        with start_active_span("recursively fetching redactions"):
+            while event_ids_to_fetch:
+                logger.debug("Also fetching redaction events %s", event_ids_to_fetch)
+
+                event_ids_to_fetch = (
+                    await _fetch_event_ids_and_get_outstanding_redactions(
+                        event_ids_to_fetch
+                    )
+                )
 
         # build a map from event_id to EventBase
         event_map: Dict[str, EventBase] = {}
@@ -1050,7 +1160,7 @@ class EventsWorkerStore(SQLBaseStore):
             if format_version is None:
                 # This means that we stored the event before we had the concept
                 # of a event format version, so it must be a V1 event.
-                format_version = EventFormatVersions.V1
+                format_version = EventFormatVersions.ROOM_V1_V2
 
             room_version_id = row.room_version_id
 
@@ -1080,10 +1190,10 @@ class EventsWorkerStore(SQLBaseStore):
                 #
                 # So, the following approximations should be adequate.
 
-                if format_version == EventFormatVersions.V1:
+                if format_version == EventFormatVersions.ROOM_V1_V2:
                     # if it's event format v1 then it must be room v1 or v2
                     room_version = RoomVersions.V1
-                elif format_version == EventFormatVersions.V2:
+                elif format_version == EventFormatVersions.ROOM_V3:
                     # if it's event format v2 then it must be room v3
                     room_version = RoomVersions.V3
                 else:
@@ -1148,7 +1258,7 @@ class EventsWorkerStore(SQLBaseStore):
                 event=original_ev, redacted_event=redacted_event
             )
 
-            self._get_event_cache.set((event_id,), cache_entry)
+            await self._get_event_cache.set((event_id,), cache_entry)
             result_map[event_id] = cache_entry
 
             if not redacted_event:
@@ -1340,6 +1450,8 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {r["event_id"] for r in rows}
 
+    @trace
+    @tag_args
     async def have_seen_events(
         self, room_id: str, event_ids: Iterable[str]
     ) -> Set[str]:
@@ -1383,7 +1495,9 @@ class EventsWorkerStore(SQLBaseStore):
         # if the event cache contains the event, obviously we've seen it.
 
         cache_results = {
-            (rid, eid) for (rid, eid) in keys if self._get_event_cache.contains((eid,))
+            (rid, eid)
+            for (rid, eid) in keys
+            if await self._get_event_cache.contains((eid,))
         }
         results = dict.fromkeys(cache_results, True)
         remaining = [k for k in keys if k not in cache_results]
@@ -1467,7 +1581,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def get_all_new_forward_event_rows(
         self, instance_name: str, last_id: int, current_id: int, limit: int
-    ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
+    ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
         """Returns new events, for the Events replication stream
 
         Args:
@@ -1483,10 +1597,11 @@ class EventsWorkerStore(SQLBaseStore):
 
         def get_all_new_forward_event_rows(
             txn: LoggingTransaction,
-        ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
+        ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
             sql = (
                 "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
+                " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL,"
+                " e.outlier"
                 " FROM events AS e"
                 " LEFT JOIN redactions USING (event_id)"
                 " LEFT JOIN state_events AS se USING (event_id)"
@@ -1500,7 +1615,8 @@ class EventsWorkerStore(SQLBaseStore):
             )
             txn.execute(sql, (last_id, current_id, instance_name, limit))
             return cast(
-                List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+                List[Tuple[int, str, str, str, str, str, str, str, bool, bool]],
+                txn.fetchall(),
             )
 
         return await self.db_pool.runInteraction(
@@ -1509,7 +1625,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def get_ex_outlier_stream_rows(
         self, instance_name: str, last_id: int, current_id: int
-    ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
+    ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
         """Returns de-outliered events, for the Events replication stream
 
         Args:
@@ -1524,11 +1640,14 @@ class EventsWorkerStore(SQLBaseStore):
 
         def get_ex_outlier_stream_rows_txn(
             txn: LoggingTransaction,
-        ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
+        ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
             sql = (
                 "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
-                " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
+                " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL,"
+                " e.outlier"
                 " FROM events AS e"
+                # NB: the next line (inner join) is what makes this query different from
+                # get_all_new_forward_event_rows.
                 " INNER JOIN ex_outlier_stream AS out USING (event_id)"
                 " LEFT JOIN redactions USING (event_id)"
                 " LEFT JOIN state_events AS se USING (event_id)"
@@ -1543,7 +1662,8 @@ class EventsWorkerStore(SQLBaseStore):
 
             txn.execute(sql, (last_id, current_id, instance_name))
             return cast(
-                List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+                List[Tuple[int, str, str, str, str, str, str, str, bool, bool]],
+                txn.fetchall(),
             )
 
         return await self.db_pool.runInteraction(
@@ -1997,7 +2117,14 @@ class EventsWorkerStore(SQLBaseStore):
                 AND room_id = ?
                 /* Make sure event is not rejected */
                 AND rejections.event_id IS NULL
-            ORDER BY origin_server_ts %s
+            /**
+             * First sort by the message timestamp. If the message timestamps are the
+             * same, we want the message that logically comes "next" (before/after
+             * the given timestamp) based on the DAG and its topological order (`depth`).
+             * Finally, we can tie-break based on when it was received on the server
+             * (`stream_ordering`).
+             */
+            ORDER BY origin_server_ts %s, depth %s, stream_ordering %s
             LIMIT 1;
         """
 
@@ -2016,7 +2143,8 @@ class EventsWorkerStore(SQLBaseStore):
                 order = "ASC"
 
             txn.execute(
-                sql_template % (comparison_operator, order), (timestamp, room_id)
+                sql_template % (comparison_operator, order, order, order),
+                (timestamp, room_id),
             )
             row = txn.fetchone()
             if row:
@@ -2081,14 +2209,92 @@ class EventsWorkerStore(SQLBaseStore):
     def _get_partial_state_events_batch_txn(
         txn: LoggingTransaction, room_id: str
     ) -> List[str]:
+        # we want to work through the events from oldest to newest, so
+        # we only want events whose prev_events do *not* have partial state - hence
+        # the 'NOT EXISTS' clause in the below.
+        #
+        # This is necessary because ordering by stream ordering isn't quite enough
+        # to ensure that we work from oldest to newest event (in particular,
+        # if an event is initially persisted as an outlier and later de-outliered,
+        # it can end up with a lower stream_ordering than its prev_events).
+        #
+        # Typically this means we'll only return one event per batch, but that's
+        # hard to do much about.
+        #
+        # See also: https://github.com/matrix-org/synapse/issues/13001
         txn.execute(
             """
             SELECT event_id FROM partial_state_events AS pse
                 JOIN events USING (event_id)
-            WHERE pse.room_id = ?
+            WHERE pse.room_id = ? AND
+               NOT EXISTS(
+                  SELECT 1 FROM event_edges AS ee
+                     JOIN partial_state_events AS prev_pse ON (prev_pse.event_id=ee.prev_event_id)
+                     WHERE ee.event_id=pse.event_id
+               )
             ORDER BY events.stream_ordering
             LIMIT 100
             """,
             (room_id,),
         )
         return [row[0] for row in txn]
+
+    def mark_event_rejected_txn(
+        self,
+        txn: LoggingTransaction,
+        event_id: str,
+        rejection_reason: Optional[str],
+    ) -> None:
+        """Mark an event that was previously accepted as rejected, or vice versa
+
+        This can happen, for example, when resyncing state during a faster join.
+
+        Args:
+            txn:
+            event_id: ID of event to update
+            rejection_reason: reason it has been rejected, or None if it is now accepted
+        """
+        if rejection_reason is None:
+            logger.info(
+                "Marking previously-processed event %s as accepted",
+                event_id,
+            )
+            self.db_pool.simple_delete_txn(
+                txn,
+                "rejections",
+                keyvalues={"event_id": event_id},
+            )
+        else:
+            logger.info(
+                "Marking previously-processed event %s as rejected(%s)",
+                event_id,
+                rejection_reason,
+            )
+            self.db_pool.simple_upsert_txn(
+                txn,
+                table="rejections",
+                keyvalues={"event_id": event_id},
+                values={
+                    "reason": rejection_reason,
+                    "last_check": self._clock.time_msec(),
+                },
+            )
+        self.db_pool.simple_update_txn(
+            txn,
+            table="events",
+            keyvalues={"event_id": event_id},
+            updatevalues={"rejection_reason": rejection_reason},
+        )
+
+        self.invalidate_get_event_cache_after_txn(txn, event_id)
+
+        # TODO(faster_joins): invalidate the cache on workers. Ideally we'd just
+        #   call '_send_invalidation_to_replication', but we actually need the other
+        #   end to call _invalidate_local_get_event_cache() rather than (just)
+        #   _get_event_cache.invalidate().
+        #
+        #   One solution might be to (somehow) get the workers to call
+        #   _invalidate_caches_for_event() (though that will invalidate more than
+        #   strictly necessary).
+        #
+        #   https://github.com/matrix-org/synapse/issues/12994
diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
index 2d7633fbd5..7270ef09da 100644
--- a/synapse/storage/databases/main/lock.py
+++ b/synapse/storage/databases/main/lock.py
@@ -129,91 +129,48 @@ class LockStore(SQLBaseStore):
         now = self._clock.time_msec()
         token = random_string(6)
 
-        if self.db_pool.engine.can_native_upsert:
-
-            def _try_acquire_lock_txn(txn: LoggingTransaction) -> bool:
-                # We take out the lock if either a) there is no row for the lock
-                # already, b) the existing row has timed out, or c) the row is
-                # for this instance (which means the process got killed and
-                # restarted)
-                sql = """
-                    INSERT INTO worker_locks (lock_name, lock_key, instance_name, token, last_renewed_ts)
-                    VALUES (?, ?, ?, ?, ?)
-                    ON CONFLICT (lock_name, lock_key)
-                    DO UPDATE
-                        SET
-                            token = EXCLUDED.token,
-                            instance_name = EXCLUDED.instance_name,
-                            last_renewed_ts = EXCLUDED.last_renewed_ts
-                        WHERE
-                            worker_locks.last_renewed_ts < ?
-                            OR worker_locks.instance_name = EXCLUDED.instance_name
-                """
-                txn.execute(
-                    sql,
-                    (
-                        lock_name,
-                        lock_key,
-                        self._instance_name,
-                        token,
-                        now,
-                        now - _LOCK_TIMEOUT_MS,
-                    ),
-                )
-
-                # We only acquired the lock if we inserted or updated the table.
-                return bool(txn.rowcount)
-
-            did_lock = await self.db_pool.runInteraction(
-                "try_acquire_lock",
-                _try_acquire_lock_txn,
-                # We can autocommit here as we're executing a single query, this
-                # will avoid serialization errors.
-                db_autocommit=True,
+        def _try_acquire_lock_txn(txn: LoggingTransaction) -> bool:
+            # We take out the lock if either a) there is no row for the lock
+            # already, b) the existing row has timed out, or c) the row is
+            # for this instance (which means the process got killed and
+            # restarted)
+            sql = """
+               INSERT INTO worker_locks (lock_name, lock_key, instance_name, token, last_renewed_ts)
+               VALUES (?, ?, ?, ?, ?)
+               ON CONFLICT (lock_name, lock_key)
+               DO UPDATE
+                   SET
+                       token = EXCLUDED.token,
+                       instance_name = EXCLUDED.instance_name,
+                       last_renewed_ts = EXCLUDED.last_renewed_ts
+                   WHERE
+                       worker_locks.last_renewed_ts < ?
+                       OR worker_locks.instance_name = EXCLUDED.instance_name
+           """
+            txn.execute(
+                sql,
+                (
+                    lock_name,
+                    lock_key,
+                    self._instance_name,
+                    token,
+                    now,
+                    now - _LOCK_TIMEOUT_MS,
+                ),
             )
-            if not did_lock:
-                return None
-
-        else:
-            # If we're on an old SQLite we emulate the above logic by first
-            # clearing out any existing stale locks and then upserting.
-
-            def _try_acquire_lock_emulated_txn(txn: LoggingTransaction) -> bool:
-                sql = """
-                    DELETE FROM worker_locks
-                    WHERE
-                        lock_name = ?
-                        AND lock_key = ?
-                        AND (last_renewed_ts < ? OR instance_name = ?)
-                """
-                txn.execute(
-                    sql,
-                    (lock_name, lock_key, now - _LOCK_TIMEOUT_MS, self._instance_name),
-                )
-
-                inserted = self.db_pool.simple_upsert_txn_emulated(
-                    txn,
-                    table="worker_locks",
-                    keyvalues={
-                        "lock_name": lock_name,
-                        "lock_key": lock_key,
-                    },
-                    values={},
-                    insertion_values={
-                        "token": token,
-                        "last_renewed_ts": self._clock.time_msec(),
-                        "instance_name": self._instance_name,
-                    },
-                )
-
-                return inserted
 
-            did_lock = await self.db_pool.runInteraction(
-                "try_acquire_lock_emulated", _try_acquire_lock_emulated_txn
-            )
+            # We only acquired the lock if we inserted or updated the table.
+            return bool(txn.rowcount)
 
-            if not did_lock:
-                return None
+        did_lock = await self.db_pool.runInteraction(
+            "try_acquire_lock",
+            _try_acquire_lock_txn,
+            # We can autocommit here as we're executing a single query, this
+            # will avoid serialization errors.
+            db_autocommit=True,
+        )
+        if not did_lock:
+            return None
 
         lock = Lock(
             self._reactor,
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 9a63f953fb..efd136a864 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -66,6 +66,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
                 "initialise_mau_threepids",
                 [],
                 [],
+                [],
                 self._initialise_reserved_users,
                 hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
             )
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 87b0d09039..f6822707e4 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -19,6 +19,8 @@ from synapse.api.errors import SynapseError
 from synapse.storage.database import LoggingTransaction
 from synapse.storage.databases.main import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.state import StateGroupWorkerStore
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.engines._base import IsolationLevel
 from synapse.types import RoomStreamToken
 
 logger = logging.getLogger(__name__)
@@ -302,7 +304,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
                 self._invalidate_cache_and_stream(
                     txn, self.have_seen_event, (room_id, event_id)
                 )
-                self._invalidate_get_event_cache(event_id)
+                self.invalidate_get_event_cache_after_txn(txn, event_id)
 
         logger.info("[purge] done")
 
@@ -317,11 +319,38 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         Returns:
             The list of state groups to delete.
         """
-        return await self.db_pool.runInteraction(
-            "purge_room", self._purge_room_txn, room_id
+
+        # This first runs the purge transaction with READ_COMMITTED isolation level,
+        # meaning any new rows in the tables will not trigger a serialization error.
+        # We then run the same purge a second time without this isolation level to
+        # purge any of those rows which were added during the first.
+
+        state_groups_to_delete = await self.db_pool.runInteraction(
+            "purge_room",
+            self._purge_room_txn,
+            room_id=room_id,
+            isolation_level=IsolationLevel.READ_COMMITTED,
+        )
+
+        state_groups_to_delete.extend(
+            await self.db_pool.runInteraction(
+                "purge_room",
+                self._purge_room_txn,
+                room_id=room_id,
+            ),
         )
 
+        return state_groups_to_delete
+
     def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]:
+        # This collides with event persistence so we cannot write new events and metadata into
+        # a room while deleting it or this transaction will fail.
+        if isinstance(self.database_engine, PostgresEngine):
+            txn.execute(
+                "SELECT room_version FROM rooms WHERE room_id = ? FOR UPDATE",
+                (room_id,),
+            )
+
         # First, fetch all the state groups that should be deleted, before
         # we delete that information.
         txn.execute(
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 86649c1e6c..5079edd1e0 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -14,11 +14,23 @@
 # limitations under the License.
 import abc
 import logging
-from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple, Union, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Collection,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+    cast,
+)
 
 from synapse.api.errors import StoreError
 from synapse.config.homeserver import ExperimentalConfig
-from synapse.push.baserules import list_with_base_rules
+from synapse.push.baserules import FilteredPushRules, PushRule, compile_push_rules
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import (
@@ -50,60 +62,30 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-def _is_experimental_rule_enabled(
-    rule_id: str, experimental_config: ExperimentalConfig
-) -> bool:
-    """Used by `_load_rules` to filter out experimental rules when they
-    have not been enabled.
-    """
-    if (
-        rule_id == "global/override/.org.matrix.msc3786.rule.room.server_acl"
-        and not experimental_config.msc3786_enabled
-    ):
-        return False
-    if (
-        rule_id == "global/underride/.org.matrix.msc3772.thread_reply"
-        and not experimental_config.msc3772_enabled
-    ):
-        return False
-    return True
-
-
 def _load_rules(
     rawrules: List[JsonDict],
     enabled_map: Dict[str, bool],
     experimental_config: ExperimentalConfig,
-) -> List[JsonDict]:
-    ruleslist = []
-    for rawrule in rawrules:
-        rule = dict(rawrule)
-        rule["conditions"] = db_to_json(rawrule["conditions"])
-        rule["actions"] = db_to_json(rawrule["actions"])
-        rule["default"] = False
-        ruleslist.append(rule)
-
-    # We're going to be mutating this a lot, so copy it. We also filter out
-    # any experimental default push rules that aren't enabled.
-    rules = [
-        rule
-        for rule in list_with_base_rules(ruleslist)
-        if _is_experimental_rule_enabled(rule["rule_id"], experimental_config)
-    ]
+) -> FilteredPushRules:
+    """Take the DB rows returned from the DB and convert them into a full
+    `FilteredPushRules` object.
+    """
 
-    for i, rule in enumerate(rules):
-        rule_id = rule["rule_id"]
+    ruleslist = [
+        PushRule(
+            rule_id=rawrule["rule_id"],
+            priority_class=rawrule["priority_class"],
+            conditions=db_to_json(rawrule["conditions"]),
+            actions=db_to_json(rawrule["actions"]),
+        )
+        for rawrule in rawrules
+    ]
 
-        if rule_id not in enabled_map:
-            continue
-        if rule.get("enabled", True) == bool(enabled_map[rule_id]):
-            continue
+    push_rules = compile_push_rules(ruleslist)
 
-        # Rules are cached across users.
-        rule = dict(rule)
-        rule["enabled"] = bool(enabled_map[rule_id])
-        rules[i] = rule
+    filtered_rules = FilteredPushRules(push_rules, enabled_map, experimental_config)
 
-    return rules
+    return filtered_rules
 
 
 # The ABCMeta metaclass ensures that it cannot be instantiated without
@@ -162,7 +144,7 @@ class PushRulesWorkerStore(
         raise NotImplementedError()
 
     @cached(max_entries=5000)
-    async def get_push_rules_for_user(self, user_id: str) -> List[JsonDict]:
+    async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
         rows = await self.db_pool.simple_select_list(
             table="push_rules",
             keyvalues={"user_name": user_id},
@@ -183,7 +165,6 @@ class PushRulesWorkerStore(
 
         return _load_rules(rows, enabled_map, self.hs.config.experimental)
 
-    @cached(max_entries=5000)
     async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
         results = await self.db_pool.simple_select_list(
             table="push_rules_enable",
@@ -216,11 +197,11 @@ class PushRulesWorkerStore(
     @cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids")
     async def bulk_get_push_rules(
         self, user_ids: Collection[str]
-    ) -> Dict[str, List[JsonDict]]:
+    ) -> Dict[str, FilteredPushRules]:
         if not user_ids:
             return {}
 
-        results: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
+        raw_rules: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
 
         rows = await self.db_pool.simple_select_many_batch(
             table="push_rules",
@@ -228,25 +209,25 @@ class PushRulesWorkerStore(
             iterable=user_ids,
             retcols=("*",),
             desc="bulk_get_push_rules",
+            batch_size=1000,
         )
 
         rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
 
         for row in rows:
-            results.setdefault(row["user_name"], []).append(row)
+            raw_rules.setdefault(row["user_name"], []).append(row)
 
         enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
 
-        for user_id, rules in results.items():
+        results: Dict[str, FilteredPushRules] = {}
+
+        for user_id, rules in raw_rules.items():
             results[user_id] = _load_rules(
                 rules, enabled_map_by_user.get(user_id, {}), self.hs.config.experimental
             )
 
         return results
 
-    @cachedList(
-        cached_method_name="get_push_rules_enabled_for_user", list_name="user_ids"
-    )
     async def bulk_get_push_rules_enabled(
         self, user_ids: Collection[str]
     ) -> Dict[str, Dict[str, bool]]:
@@ -261,6 +242,7 @@ class PushRulesWorkerStore(
             iterable=user_ids,
             retcols=("user_name", "rule_id", "enabled"),
             desc="bulk_get_push_rules_enabled",
+            batch_size=1000,
         )
         for row in rows:
             enabled = bool(row["enabled"])
@@ -344,8 +326,8 @@ class PushRuleStore(PushRulesWorkerStore):
         user_id: str,
         rule_id: str,
         priority_class: int,
-        conditions: List[Dict[str, str]],
-        actions: List[Union[JsonDict, str]],
+        conditions: Sequence[Mapping[str, str]],
+        actions: Sequence[Union[Mapping[str, Any], str]],
         before: Optional[str] = None,
         after: Optional[str] = None,
     ) -> None:
@@ -807,7 +789,6 @@ class PushRuleStore(PushRulesWorkerStore):
         self.db_pool.simple_insert_txn(txn, "push_rules_stream", values=values)
 
         txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,))
-        txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,))
         txn.call_after(
             self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
         )
@@ -816,7 +797,7 @@ class PushRuleStore(PushRulesWorkerStore):
         return self._push_rules_stream_id_gen.get_current_token()
 
     async def copy_push_rule_from_room_to_room(
-        self, new_room_id: str, user_id: str, rule: dict
+        self, new_room_id: str, user_id: str, rule: PushRule
     ) -> None:
         """Copy a single push rule from one room to another for a specific user.
 
@@ -826,21 +807,27 @@ class PushRuleStore(PushRulesWorkerStore):
             rule: A push rule.
         """
         # Create new rule id
-        rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
+        rule_id_scope = "/".join(rule.rule_id.split("/")[:-1])
         new_rule_id = rule_id_scope + "/" + new_room_id
 
+        new_conditions = []
+
         # Change room id in each condition
-        for condition in rule.get("conditions", []):
+        for condition in rule.conditions:
+            new_condition = condition
             if condition.get("key") == "room_id":
-                condition["pattern"] = new_room_id
+                new_condition = dict(condition)
+                new_condition["pattern"] = new_room_id
+
+            new_conditions.append(new_condition)
 
         # Add the rule for the new room
         await self.add_push_rule(
             user_id=user_id,
             rule_id=new_rule_id,
-            priority_class=rule["priority_class"],
-            conditions=rule["conditions"],
-            actions=rule["actions"],
+            priority_class=rule.priority_class,
+            conditions=new_conditions,
+            actions=rule.actions,
         )
 
     async def copy_push_rules_from_room_to_room_for_user(
@@ -858,8 +845,11 @@ class PushRuleStore(PushRulesWorkerStore):
         user_push_rules = await self.get_push_rules_for_user(user_id)
 
         # Get rules relating to the old room and copy them to the new room
-        for rule in user_push_rules:
-            conditions = rule.get("conditions", [])
+        for rule, enabled in user_push_rules:
+            if not enabled:
+                continue
+
+            conditions = rule.conditions
             if any(
                 (c.get("key") == "room_id" and c.get("pattern") == old_room_id)
                 for c in conditions
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 0090c9f225..ddb8e80b69 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -113,6 +113,24 @@ class ReceiptsWorkerStore(SQLBaseStore):
             prefilled_cache=receipts_stream_prefill,
         )
 
+        self.db_pool.updates.register_background_index_update(
+            "receipts_linearized_unique_index",
+            index_name="receipts_linearized_unique_index",
+            table="receipts_linearized",
+            columns=["room_id", "receipt_type", "user_id"],
+            where_clause="thread_id IS NULL",
+            unique=True,
+        )
+
+        self.db_pool.updates.register_background_index_update(
+            "receipts_graph_unique_index",
+            index_name="receipts_graph_unique_index",
+            table="receipts_graph",
+            columns=["room_id", "receipt_type", "user_id"],
+            where_clause="thread_id IS NULL",
+            unique=True,
+        )
+
     def get_max_receipt_stream_id(self) -> int:
         """Get the current max stream ID for receipts stream"""
         return self._receipts_id_gen.get_current_token()
@@ -161,7 +179,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             receipt_type: The receipt types to fetch.
 
         Returns:
-            The latest receipt, if one exists.
+            The event ID and stream ordering of the latest receipt, if one exists.
         """
 
         clause, args = make_in_list_sql_clause(
@@ -675,7 +693,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
             values={
                 "stream_id": stream_id,
                 "event_id": event_id,
+                "event_stream_ordering": stream_ordering,
                 "data": json_encoder.encode(data),
+                "thread_id": None,
             },
             # receipts_linearized has a unique constraint on
             # (user_id, room_id, receipt_type), so no need to lock
@@ -812,7 +832,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         # FIXME: This shouldn't invalidate the whole cache
         txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,))
 
-        self.db_pool.simple_delete_txn(
+        self.db_pool.simple_upsert_txn(
             txn,
             table="receipts_graph",
             keyvalues={
@@ -820,19 +840,87 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 "receipt_type": receipt_type,
                 "user_id": user_id,
             },
-        )
-        self.db_pool.simple_insert_txn(
-            txn,
-            table="receipts_graph",
             values={
-                "room_id": room_id,
-                "receipt_type": receipt_type,
-                "user_id": user_id,
                 "event_ids": json_encoder.encode(event_ids),
                 "data": json_encoder.encode(data),
+                "thread_id": None,
             },
+            # receipts_graph has a unique constraint on
+            # (user_id, room_id, receipt_type), so no need to lock
+            lock=False,
+        )
+
+
+class ReceiptsBackgroundUpdateStore(SQLBaseStore):
+    POPULATE_RECEIPT_EVENT_STREAM_ORDERING = "populate_event_stream_ordering"
+
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
+        super().__init__(database, db_conn, hs)
+
+        self.db_pool.updates.register_background_update_handler(
+            self.POPULATE_RECEIPT_EVENT_STREAM_ORDERING,
+            self._populate_receipt_event_stream_ordering,
+        )
+
+    async def _populate_receipt_event_stream_ordering(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        def _populate_receipt_event_stream_ordering_txn(
+            txn: LoggingTransaction,
+        ) -> bool:
+
+            if "max_stream_id" in progress:
+                max_stream_id = progress["max_stream_id"]
+            else:
+                txn.execute("SELECT max(stream_id) FROM receipts_linearized")
+                res = txn.fetchone()
+                if res is None or res[0] is None:
+                    return True
+                else:
+                    max_stream_id = res[0]
+
+            start = progress.get("stream_id", 0)
+            stop = start + batch_size
+
+            sql = """
+                UPDATE receipts_linearized
+                SET event_stream_ordering = (
+                    SELECT stream_ordering
+                    FROM events
+                    WHERE event_id = receipts_linearized.event_id
+                )
+                WHERE stream_id >= ? AND stream_id < ?
+            """
+            txn.execute(sql, (start, stop))
+
+            self.db_pool.updates._background_update_progress_txn(
+                txn,
+                self.POPULATE_RECEIPT_EVENT_STREAM_ORDERING,
+                {
+                    "stream_id": stop,
+                    "max_stream_id": max_stream_id,
+                },
+            )
+
+            return stop > max_stream_id
+
+        finished = await self.db_pool.runInteraction(
+            "_remove_devices_from_device_inbox_txn",
+            _populate_receipt_event_stream_ordering_txn,
         )
 
+        if finished:
+            await self.db_pool.updates._end_background_update(
+                self.POPULATE_RECEIPT_EVENT_STREAM_ORDERING
+            )
+
+        return batch_size
+
 
-class ReceiptsStore(ReceiptsWorkerStore):
+class ReceiptsStore(ReceiptsWorkerStore, ReceiptsBackgroundUpdateStore):
     pass
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index cb63cd9b7d..ac821878b0 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -69,9 +69,9 @@ class TokenLookupResult:
     """
 
     user_id: str
+    token_id: int
     is_guest: bool = False
     shadow_banned: bool = False
-    token_id: Optional[int] = None
     device_id: Optional[str] = None
     valid_until_ms: Optional[int] = None
     token_owner: str = attr.ib()
@@ -175,6 +175,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 "is_guest",
                 "admin",
                 "consent_version",
+                "consent_ts",
                 "consent_server_notice_sent",
                 "appservice_id",
                 "creation_ts",
@@ -2227,7 +2228,10 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
                 txn,
                 table="users",
                 keyvalues={"name": user_id},
-                updatevalues={"consent_version": consent_version},
+                updatevalues={
+                    "consent_version": consent_version,
+                    "consent_ts": self._clock.time_msec(),
+                },
             )
             self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
 
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index b457bc189e..7bd27790eb 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -62,7 +62,6 @@ class RelationsWorkerStore(SQLBaseStore):
         room_id: str,
         relation_type: Optional[str] = None,
         event_type: Optional[str] = None,
-        aggregation_key: Optional[str] = None,
         limit: int = 5,
         direction: str = "b",
         from_token: Optional[StreamToken] = None,
@@ -76,7 +75,6 @@ class RelationsWorkerStore(SQLBaseStore):
             room_id: The room the event belongs to.
             relation_type: Only fetch events with this relation type, if given.
             event_type: Only fetch events with this event type, if given.
-            aggregation_key: Only fetch events with this aggregation key, if given.
             limit: Only fetch the most recent `limit` events.
             direction: Whether to fetch the most recent first (`"b"`) or the
                 oldest first (`"f"`).
@@ -105,10 +103,6 @@ class RelationsWorkerStore(SQLBaseStore):
             where_clause.append("type = ?")
             where_args.append(event_type)
 
-        if aggregation_key:
-            where_clause.append("aggregation_key = ?")
-            where_args.append(aggregation_key)
-
         pagination_clause = generate_pagination_where_clause(
             direction=direction,
             column_names=("topological_ordering", "stream_ordering"),
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 13d6a1d5c0..bef66f1992 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -175,7 +175,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
                   rooms.creator, state.encryption, state.is_federatable AS federatable,
                   rooms.is_public AS public, state.join_rules, state.guest_access,
                   state.history_visibility, curr.current_state_events AS state_events,
-                  state.avatar, state.topic
+                  state.avatar, state.topic, state.room_type
                 FROM rooms
                 LEFT JOIN room_stats_state state USING (room_id)
                 LEFT JOIN room_stats_current curr USING (room_id)
@@ -207,7 +207,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
     def _construct_room_type_where_clause(
         self, room_types: Union[List[Union[str, None]], None]
     ) -> Tuple[Union[str, None], List[str]]:
-        if not room_types or not self.config.experimental.msc3827_enabled:
+        if not room_types:
             return None, []
         else:
             # We use None when we want get rooms without a type
@@ -596,7 +596,8 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             SELECT state.room_id, state.name, state.canonical_alias, curr.joined_members,
               curr.local_users_in_room, rooms.room_version, rooms.creator,
               state.encryption, state.is_federatable, rooms.is_public, state.join_rules,
-              state.guest_access, state.history_visibility, curr.current_state_events
+              state.guest_access, state.history_visibility, curr.current_state_events,
+              state.room_type
             FROM room_stats_state state
             INNER JOIN room_stats_current curr USING (room_id)
             INNER JOIN rooms USING (room_id)
@@ -640,12 +641,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
                         "version": room[5],
                         "creator": room[6],
                         "encryption": room[7],
-                        "federatable": room[8],
-                        "public": room[9],
+                        # room_stats_state.federatable is an integer on sqlite.
+                        "federatable": bool(room[8]),
+                        # rooms.is_public is an integer on sqlite.
+                        "public": bool(room[9]),
                         "join_rules": room[10],
                         "guest_access": room[11],
                         "history_visibility": room[12],
                         "state_events": room[13],
+                        "room_type": room[14],
                     }
                 )
 
@@ -1181,8 +1185,9 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             )
             return False
 
-    @staticmethod
-    def _clear_partial_state_room_txn(txn: LoggingTransaction, room_id: str) -> None:
+    def _clear_partial_state_room_txn(
+        self, txn: LoggingTransaction, room_id: str
+    ) -> None:
         DatabasePool.simple_delete_txn(
             txn,
             table="partial_state_rooms_servers",
@@ -1193,7 +1198,9 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             table="partial_state_rooms",
             keyvalues={"room_id": room_id},
         )
+        self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
 
+    @cached()
     async def is_partial_state_room(self, room_id: str) -> bool:
         """Checks if this room has partial state.
 
@@ -1767,9 +1774,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
             servers,
         )
 
-    @staticmethod
     def _store_partial_state_room_txn(
-        txn: LoggingTransaction, room_id: str, servers: Collection[str]
+        self, txn: LoggingTransaction, room_id: str, servers: Collection[str]
     ) -> None:
         DatabasePool.simple_insert_txn(
             txn,
@@ -1784,6 +1790,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
             keys=("room_id", "server_name"),
             values=((room_id, s) for s in servers),
         )
+        self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
 
     async def maybe_store_room_on_outlier_membership(
         self, room_id: str, room_version: RoomVersion
@@ -1999,9 +2006,15 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
 
             where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
 
+            # We join on room_stats_state despite not using any columns from it
+            # because the join can influence the number of rows returned;
+            # e.g. a room that doesn't have state, maybe because it was deleted.
+            # The query returning the total count should be consistent with
+            # the query returning the results.
             sql = """
                 SELECT COUNT(*) as total_event_reports
                 FROM event_reports AS er
+                JOIN room_stats_state ON room_stats_state.room_id = er.room_id
                 {}
                 """.format(
                 where_clause
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 0b5e4e4254..a8d224602a 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -21,6 +21,7 @@ from typing import (
     FrozenSet,
     Iterable,
     List,
+    Mapping,
     Optional,
     Set,
     Tuple,
@@ -30,13 +31,8 @@ from typing import (
 import attr
 
 from synapse.api.constants import EventTypes, Membership
-from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
 from synapse.metrics import LaterGauge
-from synapse.metrics.background_process_metrics import (
-    run_as_background_process,
-    wrap_as_background_process,
-)
+from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
@@ -56,6 +52,8 @@ from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
+from synapse.util.cancellation import cancellable
+from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
 if TYPE_CHECKING:
@@ -90,15 +88,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         # at a time. Keyed by room_id.
         self._joined_host_linearizer = Linearizer("_JoinedHostsCache")
 
-        # Is the current_state_events.membership up to date? Or is the
-        # background update still running?
-        self._current_state_events_membership_up_to_date = False
-
-        txn = db_conn.cursor(
-            txn_name="_check_safe_current_state_events_membership_updated"
-        )
-        self._check_safe_current_state_events_membership_updated_txn(txn)
-        txn.close()
+        self._server_notices_mxid = hs.config.servernotices.server_notices_mxid
 
         if (
             self.hs.config.worker.run_background_tasks
@@ -156,58 +146,41 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         self._known_servers_count = max([count, 1])
         return self._known_servers_count
 
-    def _check_safe_current_state_events_membership_updated_txn(
-        self, txn: LoggingTransaction
-    ) -> None:
-        """Checks if it is safe to assume the new current_state_events
-        membership column is up to date
-        """
-
-        pending_update = self.db_pool.simple_select_one_txn(
-            txn,
-            table="background_updates",
-            keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME},
-            retcols=["update_name"],
-            allow_none=True,
-        )
-
-        self._current_state_events_membership_up_to_date = not pending_update
-
-        # If the update is still running, reschedule to run.
-        if pending_update:
-            self._clock.call_later(
-                15.0,
-                run_as_background_process,
-                "_check_safe_current_state_events_membership_updated",
-                self.db_pool.runInteraction,
-                "_check_safe_current_state_events_membership_updated",
-                self._check_safe_current_state_events_membership_updated_txn,
-            )
-
-    @cached(max_entries=100000, iterable=True, prune_unread_entries=False)
+    @cached(max_entries=100000, iterable=True)
     async def get_users_in_room(self, room_id: str) -> List[str]:
+        """
+        Returns a list of users in the room sorted by longest in the room first
+        (aka. with the lowest depth). This is done to match the sort in
+        `get_current_hosts_in_room()` and so we can re-use the cache but it's
+        not horrible to have here either.
+
+        Uses `m.room.member`s in the room state at the current forward extremities to
+        determine which users are in the room.
+
+        Will return inaccurate results for rooms with partial state, since the state for
+        the forward extremities of those rooms will exclude most members. We may also
+        calculate room state incorrectly for such rooms and believe that a member is or
+        is not in the room when the opposite is true.
+        """
         return await self.db_pool.runInteraction(
             "get_users_in_room", self.get_users_in_room_txn, room_id
         )
 
     def get_users_in_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[str]:
-        # If we can assume current_state_events.membership is up to date
-        # then we can avoid a join, which is a Very Good Thing given how
-        # frequently this function gets called.
-        if self._current_state_events_membership_up_to_date:
-            sql = """
-                SELECT state_key FROM current_state_events
-                WHERE type = 'm.room.member' AND room_id = ? AND membership = ?
-            """
-        else:
-            sql = """
-                SELECT state_key FROM room_memberships as m
-                INNER JOIN current_state_events as c
-                ON m.event_id = c.event_id
-                AND m.room_id = c.room_id
-                AND m.user_id = c.state_key
-                WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?
-            """
+        """
+        Returns a list of users in the room sorted by longest in the room first
+        (aka. with the lowest depth). This is done to match the sort in
+        `get_current_hosts_in_room()` and so we can re-use the cache but it's
+        not horrible to have here either.
+        """
+        sql = """
+            SELECT c.state_key FROM current_state_events as c
+            /* Get the depth of the event from the events table */
+            INNER JOIN events AS e USING (event_id)
+            WHERE c.type = 'm.room.member' AND c.room_id = ? AND membership = ?
+            /* Sorted by lowest depth first */
+            ORDER BY e.depth ASC;
+        """
 
         txn.execute(sql, (room_id, Membership.JOIN))
         return [r[0] for r in txn]
@@ -244,7 +217,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             txn: LoggingTransaction,
         ) -> Dict[str, ProfileInfo]:
             clause, ids = make_in_list_sql_clause(
-                self.database_engine, "m.user_id", user_ids
+                self.database_engine, "c.state_key", user_ids
             )
 
             sql = """
@@ -282,6 +255,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         Returns:
             A mapping from user ID to ProfileInfo.
+
+        Preconditions:
+          - There is full state available for the room (it is not partial-stated).
         """
 
         def _get_users_in_room_with_profiles(
@@ -321,28 +297,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             # We do this all in one transaction to keep the cache small.
             # FIXME: get rid of this when we have room_stats
 
-            # If we can assume current_state_events.membership is up to date
-            # then we can avoid a join, which is a Very Good Thing given how
-            # frequently this function gets called.
-            if self._current_state_events_membership_up_to_date:
-                # Note, rejected events will have a null membership field, so
-                # we we manually filter them out.
-                sql = """
-                    SELECT count(*), membership FROM current_state_events
-                    WHERE type = 'm.room.member' AND room_id = ?
-                        AND membership IS NOT NULL
-                    GROUP BY membership
-                """
-            else:
-                sql = """
-                    SELECT count(*), m.membership FROM room_memberships as m
-                    INNER JOIN current_state_events as c
-                    ON m.event_id = c.event_id
-                    AND m.room_id = c.room_id
-                    AND m.user_id = c.state_key
-                    WHERE c.type = 'm.room.member' AND c.room_id = ?
-                    GROUP BY m.membership
-                """
+            # Note, rejected events will have a null membership field, so
+            # we we manually filter them out.
+            sql = """
+                SELECT count(*), membership FROM current_state_events
+                WHERE type = 'm.room.member' AND room_id = ?
+                    AND membership IS NOT NULL
+                GROUP BY membership
+            """
 
             txn.execute(sql, (room_id,))
             res: Dict[str, MemberSummary] = {}
@@ -351,30 +313,18 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
             # we order by membership and then fairly arbitrarily by event_id so
             # heroes are consistent
-            if self._current_state_events_membership_up_to_date:
-                # Note, rejected events will have a null membership field, so
-                # we we manually filter them out.
-                sql = """
-                    SELECT state_key, membership, event_id
-                    FROM current_state_events
-                    WHERE type = 'm.room.member' AND room_id = ?
-                        AND membership IS NOT NULL
-                    ORDER BY
-                        CASE membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
-                        event_id ASC
-                    LIMIT ?
-                """
-            else:
-                sql = """
-                    SELECT c.state_key, m.membership, c.event_id
-                    FROM room_memberships as m
-                    INNER JOIN current_state_events as c USING (room_id, event_id)
-                    WHERE c.type = 'm.room.member' AND c.room_id = ?
-                    ORDER BY
-                        CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
-                        c.event_id ASC
-                    LIMIT ?
-                """
+            # Note, rejected events will have a null membership field, so
+            # we we manually filter them out.
+            sql = """
+                SELECT state_key, membership, event_id
+                FROM current_state_events
+                WHERE type = 'm.room.member' AND room_id = ?
+                    AND membership IS NOT NULL
+                ORDER BY
+                    CASE membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
+                    event_id ASC
+                LIMIT ?
+            """
 
             # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user.
             txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6))
@@ -530,6 +480,47 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             desc="get_local_users_in_room",
         )
 
+    async def check_local_user_in_room(self, user_id: str, room_id: str) -> bool:
+        """
+        Check whether a given local user is currently joined to the given room.
+
+        Returns:
+            A boolean indicating whether the user is currently joined to the room
+
+        Raises:
+            Exeption when called with a non-local user to this homeserver
+        """
+        if not self.hs.is_mine_id(user_id):
+            raise Exception(
+                "Cannot call 'check_local_user_in_room' on "
+                "non-local user %s" % (user_id,),
+            )
+
+        (
+            membership,
+            member_event_id,
+        ) = await self.get_local_current_membership_for_user_in_room(
+            user_id=user_id,
+            room_id=room_id,
+        )
+
+        return membership == Membership.JOIN
+
+    async def is_server_notice_room(self, room_id: str) -> bool:
+        """
+        Determines whether the given room is a 'Server Notices' room, used for
+        sending server notices to a user.
+
+        This is determined by seeing whether the server notices user is present
+        in the room.
+        """
+        if self._server_notices_mxid is None:
+            return False
+        is_server_notices_room = await self.check_local_user_in_room(
+            user_id=self._server_notices_mxid, room_id=room_id
+        )
+        return is_server_notices_room
+
     async def get_local_current_membership_for_user_in_room(
         self, user_id: str, room_id: str
     ) -> Tuple[Optional[str], Optional[str]]:
@@ -562,7 +553,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return results_dict.get("membership"), results_dict.get("event_id")
 
-    @cached(max_entries=500000, iterable=True, prune_unread_entries=False)
+    @cached(max_entries=500000, iterable=True)
     async def get_rooms_for_user_with_stream_ordering(
         self, user_id: str
     ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
@@ -591,27 +582,15 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         # We use `current_state_events` here and not `local_current_membership`
         # as a) this gets called with remote users and b) this only gets called
         # for rooms the server is participating in.
-        if self._current_state_events_membership_up_to_date:
-            sql = """
-                SELECT room_id, e.instance_name, e.stream_ordering
-                FROM current_state_events AS c
-                INNER JOIN events AS e USING (room_id, event_id)
-                WHERE
-                    c.type = 'm.room.member'
-                    AND c.state_key = ?
-                    AND c.membership = ?
-            """
-        else:
-            sql = """
-                SELECT room_id, e.instance_name, e.stream_ordering
-                FROM current_state_events AS c
-                INNER JOIN room_memberships AS m USING (room_id, event_id)
-                INNER JOIN events AS e USING (room_id, event_id)
-                WHERE
-                    c.type = 'm.room.member'
-                    AND c.state_key = ?
-                    AND m.membership = ?
-            """
+        sql = """
+            SELECT room_id, e.instance_name, e.stream_ordering
+            FROM current_state_events AS c
+            INNER JOIN events AS e USING (room_id, event_id)
+            WHERE
+                c.type = 'm.room.member'
+                AND c.state_key = ?
+                AND c.membership = ?
+        """
 
         txn.execute(sql, (user_id, Membership.JOIN))
         return frozenset(
@@ -649,27 +628,15 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             user_ids,
         )
 
-        if self._current_state_events_membership_up_to_date:
-            sql = f"""
-                SELECT c.state_key, room_id, e.instance_name, e.stream_ordering
-                FROM current_state_events AS c
-                INNER JOIN events AS e USING (room_id, event_id)
-                WHERE
-                    c.type = 'm.room.member'
-                    AND c.membership = ?
-                    AND {clause}
-            """
-        else:
-            sql = f"""
-                SELECT c.state_key, room_id, e.instance_name, e.stream_ordering
-                FROM current_state_events AS c
-                INNER JOIN room_memberships AS m USING (room_id, event_id)
-                INNER JOIN events AS e USING (room_id, event_id)
-                WHERE
-                    c.type = 'm.room.member'
-                    AND m.membership = ?
-                    AND {clause}
-            """
+        sql = f"""
+            SELECT c.state_key, room_id, e.instance_name, e.stream_ordering
+            FROM current_state_events AS c
+            INNER JOIN events AS e USING (room_id, event_id)
+            WHERE
+                c.type = 'm.room.member'
+                AND c.membership = ?
+                AND {clause}
+        """
 
         txn.execute(sql, [Membership.JOIN] + args)
 
@@ -720,6 +687,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             _get_users_server_still_shares_room_with_txn,
         )
 
+    @cancellable
     async def get_rooms_for_user(
         self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None
     ) -> FrozenSet[str]:
@@ -733,25 +701,76 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
         return frozenset(r.room_id for r in rooms)
 
-    @cached(
-        max_entries=500000,
-        cache_context=True,
-        iterable=True,
-        prune_unread_entries=False,
+    @cached(max_entries=10000)
+    async def does_pair_of_users_share_a_room(
+        self, user_id: str, other_user_id: str
+    ) -> bool:
+        raise NotImplementedError()
+
+    @cachedList(
+        cached_method_name="does_pair_of_users_share_a_room", list_name="other_user_ids"
     )
-    async def get_users_who_share_room_with_user(
-        self, user_id: str, cache_context: _CacheContext
+    async def _do_users_share_a_room(
+        self, user_id: str, other_user_ids: Collection[str]
+    ) -> Mapping[str, Optional[bool]]:
+        """Return mapping from user ID to whether they share a room with the
+        given user.
+
+        Note: `None` and `False` are equivalent and mean they don't share a
+        room.
+        """
+
+        def do_users_share_a_room_txn(
+            txn: LoggingTransaction, user_ids: Collection[str]
+        ) -> Dict[str, bool]:
+            clause, args = make_in_list_sql_clause(
+                self.database_engine, "state_key", user_ids
+            )
+
+            # This query works by fetching both the list of rooms for the target
+            # user and the set of other users, and then checking if there is any
+            # overlap.
+            sql = f"""
+                SELECT b.state_key
+                FROM (
+                    SELECT room_id FROM current_state_events
+                    WHERE type = 'm.room.member' AND membership = 'join' AND state_key = ?
+                ) AS a
+                INNER JOIN (
+                    SELECT room_id, state_key FROM current_state_events
+                    WHERE type = 'm.room.member' AND membership = 'join' AND {clause}
+                ) AS b using (room_id)
+                LIMIT 1
+            """
+
+            txn.execute(sql, (user_id, *args))
+            return {u: True for u, in txn}
+
+        to_return = {}
+        for batch_user_ids in batch_iter(other_user_ids, 1000):
+            res = await self.db_pool.runInteraction(
+                "do_users_share_a_room", do_users_share_a_room_txn, batch_user_ids
+            )
+            to_return.update(res)
+
+        return to_return
+
+    async def do_users_share_a_room(
+        self, user_id: str, other_user_ids: Collection[str]
     ) -> Set[str]:
+        """Return the set of users who share a room with the first users"""
+
+        user_dict = await self._do_users_share_a_room(user_id, other_user_ids)
+
+        return {u for u, share_room in user_dict.items() if share_room}
+
+    async def get_users_who_share_room_with_user(self, user_id: str) -> Set[str]:
         """Returns the set of users who share a room with `user_id`"""
-        room_ids = await self.get_rooms_for_user(
-            user_id, on_invalidate=cache_context.invalidate
-        )
+        room_ids = await self.get_rooms_for_user(user_id)
 
         user_who_share_room = set()
         for room_id in room_ids:
-            user_ids = await self.get_users_in_room(
-                room_id, on_invalidate=cache_context.invalidate
-            )
+            user_ids = await self.get_users_in_room(room_id)
             user_who_share_room.update(user_ids)
 
         return user_who_share_room
@@ -780,161 +799,92 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return shared_room_ids or frozenset()
 
-    async def get_joined_users_from_context(
-        self, event: EventBase, context: EventContext
-    ) -> Dict[str, ProfileInfo]:
-        state_group: Union[object, int] = context.state_group
-        if not state_group:
-            # If state_group is None it means it has yet to be assigned a
-            # state group, i.e. we need to make sure that calls with a state_group
-            # of None don't hit previous cached calls with a None state_group.
-            # To do this we set the state_group to a new object as object() != object()
-            state_group = object()
+    async def get_joined_user_ids_from_state(
+        self, room_id: str, state: StateMap[str]
+    ) -> Set[str]:
+        """
+        For a given set of state IDs, get a set of user IDs in the room.
 
-        current_state_ids = await context.get_current_state_ids()
-        assert current_state_ids is not None
-        assert state_group is not None
-        return await self._get_joined_users_from_context(
-            event.room_id, state_group, current_state_ids, event=event, context=context
-        )
+        This method checks the local event cache, before calling
+        `_get_user_ids_from_membership_event_ids` for any uncached events.
+        """
 
-    async def get_joined_users_from_state(
-        self, room_id: str, state_entry: "_StateCacheEntry"
-    ) -> Dict[str, ProfileInfo]:
-        state_group: Union[object, int] = state_entry.state_group
-        if not state_group:
-            # If state_group is None it means it has yet to be assigned a
-            # state group, i.e. we need to make sure that calls with a state_group
-            # of None don't hit previous cached calls with a None state_group.
-            # To do this we set the state_group to a new object as object() != object()
-            state_group = object()
+        with Measure(self._clock, "get_joined_user_ids_from_state"):
+            users_in_room = set()
+            member_event_ids = [
+                e_id for key, e_id in state.items() if key[0] == EventTypes.Member
+            ]
 
-        assert state_group is not None
-        with Measure(self._clock, "get_joined_users_from_state"):
-            return await self._get_joined_users_from_context(
-                room_id, state_group, state_entry.state, context=state_entry
-            )
+            # We check if we have any of the member event ids in the event cache
+            # before we ask the DB
 
-    @cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
-    async def _get_joined_users_from_context(
-        self,
-        room_id: str,
-        state_group: Union[object, int],
-        current_state_ids: StateMap[str],
-        cache_context: _CacheContext,
-        event: Optional[EventBase] = None,
-        context: Optional[Union[EventContext, "_StateCacheEntry"]] = None,
-    ) -> Dict[str, ProfileInfo]:
-        # We don't use `state_group`, it's there so that we can cache based
-        # on it. However, it's important that it's never None, since two current_states
-        # with a state_group of None are likely to be different.
-        assert state_group is not None
-
-        users_in_room = {}
-        member_event_ids = [
-            e_id
-            for key, e_id in current_state_ids.items()
-            if key[0] == EventTypes.Member
-        ]
-
-        if context is not None:
-            # If we have a context with a delta from a previous state group,
-            # check if we also have the result from the previous group in cache.
-            # If we do then we can reuse that result and simply update it with
-            # any membership changes in `delta_ids`
-            if context.prev_group and context.delta_ids:
-                prev_res = self._get_joined_users_from_context.cache.get_immediate(
-                    (room_id, context.prev_group), None
-                )
-                if prev_res and isinstance(prev_res, dict):
-                    users_in_room = dict(prev_res)
-                    member_event_ids = [
-                        e_id
-                        for key, e_id in context.delta_ids.items()
-                        if key[0] == EventTypes.Member
-                    ]
-                    for etype, state_key in context.delta_ids:
-                        if etype == EventTypes.Member:
-                            users_in_room.pop(state_key, None)
-
-        # We check if we have any of the member event ids in the event cache
-        # before we ask the DB
-
-        # We don't update the event cache hit ratio as it completely throws off
-        # the hit ratio counts. After all, we don't populate the cache if we
-        # miss it here
-        event_map = self._get_events_from_cache(member_event_ids, update_metrics=False)
-
-        missing_member_event_ids = []
-        for event_id in member_event_ids:
-            ev_entry = event_map.get(event_id)
-            if ev_entry and not ev_entry.event.rejected_reason:
-                if ev_entry.event.membership == Membership.JOIN:
-                    users_in_room[ev_entry.event.state_key] = ProfileInfo(
-                        display_name=ev_entry.event.content.get("displayname", None),
-                        avatar_url=ev_entry.event.content.get("avatar_url", None),
-                    )
-            else:
-                missing_member_event_ids.append(event_id)
-
-        if missing_member_event_ids:
-            event_to_memberships = await self._get_joined_profiles_from_event_ids(
-                missing_member_event_ids
+            # We don't update the event cache hit ratio as it completely throws off
+            # the hit ratio counts. After all, we don't populate the cache if we
+            # miss it here
+            event_map = self._get_events_from_local_cache(
+                member_event_ids, update_metrics=False
             )
-            users_in_room.update(row for row in event_to_memberships.values() if row)
-
-        if event is not None and event.type == EventTypes.Member:
-            if event.membership == Membership.JOIN:
-                if event.event_id in member_event_ids:
-                    users_in_room[event.state_key] = ProfileInfo(
-                        display_name=event.content.get("displayname", None),
-                        avatar_url=event.content.get("avatar_url", None),
+
+            missing_member_event_ids = []
+            for event_id in member_event_ids:
+                ev_entry = event_map.get(event_id)
+                if ev_entry and not ev_entry.event.rejected_reason:
+                    if ev_entry.event.membership == Membership.JOIN:
+                        users_in_room.add(ev_entry.event.state_key)
+                else:
+                    missing_member_event_ids.append(event_id)
+
+            if missing_member_event_ids:
+                event_to_memberships = (
+                    await self._get_user_ids_from_membership_event_ids(
+                        missing_member_event_ids
                     )
+                )
+                users_in_room.update(
+                    user_id for user_id in event_to_memberships.values() if user_id
+                )
 
-        return users_in_room
+            return users_in_room
 
-    @cached(max_entries=10000)
-    def _get_joined_profile_from_event_id(
+    @cached(
+        max_entries=10000,
+        # This name matches the old function that has been replaced - the cache name
+        # is kept here to maintain backwards compatibility.
+        name="_get_joined_profile_from_event_id",
+    )
+    def _get_user_id_from_membership_event_id(
         self, event_id: str
     ) -> Optional[Tuple[str, ProfileInfo]]:
         raise NotImplementedError()
 
     @cachedList(
-        cached_method_name="_get_joined_profile_from_event_id",
+        cached_method_name="_get_user_id_from_membership_event_id",
         list_name="event_ids",
     )
-    async def _get_joined_profiles_from_event_ids(
+    async def _get_user_ids_from_membership_event_ids(
         self, event_ids: Iterable[str]
-    ) -> Dict[str, Optional[Tuple[str, ProfileInfo]]]:
+    ) -> Dict[str, Optional[str]]:
         """For given set of member event_ids check if they point to a join
-        event and if so return the associated user and profile info.
+        event.
 
         Args:
             event_ids: The member event IDs to lookup
 
         Returns:
-            Map from event ID to `user_id` and ProfileInfo (or None if not join event).
+            Map from event ID to `user_id`, or None if event is not a join.
         """
 
         rows = await self.db_pool.simple_select_many_batch(
             table="room_memberships",
             column="event_id",
             iterable=event_ids,
-            retcols=("user_id", "display_name", "avatar_url", "event_id"),
+            retcols=("user_id", "event_id"),
             keyvalues={"membership": Membership.JOIN},
-            batch_size=500,
-            desc="_get_joined_profiles_from_event_ids",
+            batch_size=1000,
+            desc="_get_user_ids_from_membership_event_ids",
         )
 
-        return {
-            row["event_id"]: (
-                row["user_id"],
-                ProfileInfo(
-                    avatar_url=row["avatar_url"], display_name=row["display_name"]
-                ),
-            )
-            for row in rows
-        }
+        return {row["event_id"]: row["user_id"] for row in rows}
 
     @cached(max_entries=10000)
     async def is_host_joined(self, room_id: str, host: str) -> bool:
@@ -980,44 +930,88 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return True
 
     @cached(iterable=True, max_entries=10000)
-    async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
-        """Get current hosts in room based on current state."""
+    async def get_current_hosts_in_room(self, room_id: str) -> List[str]:
+        """
+        Get current hosts in room based on current state.
+
+        The heuristic of sorting by servers who have been in the room the
+        longest is good because they're most likely to have anything we ask
+        about.
+
+        Uses `m.room.member`s in the room state at the current forward extremities to
+        determine which hosts are in the room.
+
+        Will return inaccurate results for rooms with partial state, since the state for
+        the forward extremities of those rooms will exclude most members. We may also
+        calculate room state incorrectly for such rooms and believe that a host is or
+        is not in the room when the opposite is true.
+
+        Returns:
+            Returns a list of servers sorted by longest in the room first. (aka.
+            sorted by join with the lowest depth first).
+        """
 
         # First we check if we already have `get_users_in_room` in the cache, as
         # we can just calculate result from that
         users = self.get_users_in_room.cache.get_immediate(
             (room_id,), None, update_metrics=False
         )
-        if users is not None:
-            return {get_domain_from_id(u) for u in users}
-
-        if isinstance(self.database_engine, Sqlite3Engine):
+        if users is None and isinstance(self.database_engine, Sqlite3Engine):
             # If we're using SQLite then let's just always use
             # `get_users_in_room` rather than funky SQL.
             users = await self.get_users_in_room(room_id)
-            return {get_domain_from_id(u) for u in users}
+
+        if users is not None:
+            # Because `users` is sorted from lowest -> highest depth, the list
+            # of domains will also be sorted that way.
+            domains: List[str] = []
+            # We use a `Set` just for fast lookups
+            domain_set: Set[str] = set()
+            for u in users:
+                if ":" not in u:
+                    continue
+                domain = get_domain_from_id(u)
+                if domain not in domain_set:
+                    domain_set.add(domain)
+                    domains.append(domain)
+            return domains
 
         # For PostgreSQL we can use a regex to pull out the domains from the
         # joined users in `current_state_events` via regex.
 
-        def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]:
+        def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> List[str]:
+            # Returns a list of servers currently joined in the room sorted by
+            # longest in the room first (aka. with the lowest depth). The
+            # heuristic of sorting by servers who have been in the room the
+            # longest is good because they're most likely to have anything we
+            # ask about.
             sql = """
-                SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$')
-                FROM current_state_events
+                SELECT
+                    /* Match the domain part of the MXID */
+                    substring(c.state_key FROM '@[^:]*:(.*)$') as server_domain
+                FROM current_state_events c
+                /* Get the depth of the event from the events table */
+                INNER JOIN events AS e USING (event_id)
                 WHERE
-                    type = 'm.room.member'
-                    AND membership = 'join'
-                    AND room_id = ?
+                    /* Find any join state events in the room */
+                    c.type = 'm.room.member'
+                    AND c.membership = 'join'
+                    AND c.room_id = ?
+                /* Group all state events from the same domain into their own buckets (groups) */
+                GROUP BY server_domain
+                /* Sorted by lowest depth first */
+                ORDER BY min(e.depth) ASC;
             """
             txn.execute(sql, (room_id,))
-            return {d for d, in txn}
+            # `server_domain` will be `NULL` for malformed MXIDs with no colons.
+            return [d for d, in txn if d is not None]
 
         return await self.db_pool.runInteraction(
             "get_current_hosts_in_room", get_current_hosts_in_room_txn
         )
 
     async def get_joined_hosts(
-        self, room_id: str, state_entry: "_StateCacheEntry"
+        self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
     ) -> FrozenSet[str]:
         state_group: Union[object, int] = state_entry.state_group
         if not state_group:
@@ -1030,7 +1024,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         assert state_group is not None
         with Measure(self._clock, "get_joined_hosts"):
             return await self._get_joined_hosts(
-                room_id, state_group, state_entry=state_entry
+                room_id, state_group, state, state_entry=state_entry
             )
 
     @cached(num_args=2, max_entries=10000, iterable=True)
@@ -1038,6 +1032,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         self,
         room_id: str,
         state_group: Union[object, int],
+        state: StateMap[str],
         state_entry: "_StateCacheEntry",
     ) -> FrozenSet[str]:
         # We don't use `state_group`, it's there so that we can cache based on
@@ -1092,12 +1087,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             else:
                 # The cache doesn't match the state group or prev state group,
                 # so we calculate the result from first principles.
-                joined_users = await self.get_joined_users_from_state(
-                    room_id, state_entry
+                joined_user_ids = await self.get_joined_user_ids_from_state(
+                    room_id, state
                 )
 
                 cache.hosts_to_joined_users = {}
-                for user_id in joined_users:
+                for user_id in joined_user_ids:
                     host = intern_string(get_domain_from_id(user_id))
                     cache.hosts_to_joined_users.setdefault(host, set()).add(user_id)
 
@@ -1176,6 +1171,30 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
         )
 
+    async def is_locally_forgotten_room(self, room_id: str) -> bool:
+        """Returns whether all local users have forgotten this room_id.
+
+        Args:
+            room_id: The room ID to query.
+
+        Returns:
+            Whether the room is forgotten.
+        """
+
+        sql = """
+            SELECT count(*) > 0 FROM local_current_membership
+            INNER JOIN room_memberships USING (room_id, event_id)
+            WHERE
+                room_id = ?
+                AND forgotten = 0;
+        """
+
+        rows = await self.db_pool.execute("is_forgotten_room", None, sql, room_id)
+
+        # `count(*)` returns always an integer
+        # If any rows still exist it means someone has not forgotten this room yet
+        return not rows[0][0]
+
     async def get_rooms_user_has_been_in(self, user_id: str) -> Set[str]:
         """Get all rooms that the user has ever been in.
 
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 9674c4a757..af7bebee80 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -23,6 +23,7 @@ from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
+from synapse.logging.opentracing import trace
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
     DatabasePool,
@@ -36,6 +37,7 @@ from synapse.storage.state import StateFilter
 from synapse.types import JsonDict, JsonMapping, StateMap
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.cancellation import cancellable
 from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
@@ -142,6 +144,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return room_version
 
+    @trace
     async def get_metadata_for_events(
         self, event_ids: Collection[str]
     ) -> Dict[str, EventMetadata]:
@@ -281,6 +284,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
 
     # FIXME: how should this be cached?
+    @cancellable
     async def get_partial_filtered_current_state_ids(
         self, room_id: str, state_filter: Optional[StateFilter] = None
     ) -> StateMap[str]:
@@ -419,15 +423,22 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         # anything that was rejected should have the same state as its
         # predecessor.
         if context.rejected:
-            assert context.state_group == context.state_group_before_event
+            state_group = context.state_group_before_event
+        else:
+            state_group = context.state_group
 
         self.db_pool.simple_update_txn(
             txn,
             table="event_to_state_groups",
             keyvalues={"event_id": event.event_id},
-            updatevalues={"state_group": context.state_group},
+            updatevalues={"state_group": state_group},
         )
 
+        # the event may now be rejected where it was not before, or vice versa,
+        # in which case we need to update the rejected flags.
+        if bool(context.rejected) != (event.rejected_reason is not None):
+            self.mark_event_rejected_txn(txn, event.event_id, context.rejected)
+
         self.db_pool.simple_delete_one_txn(
             txn,
             table="partial_state_events",
@@ -440,7 +451,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         txn.call_after(
             self._get_state_group_for_event.prefill,
             (event.event_id,),
-            context.state_group,
+            state_group,
         )
 
 
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index b4c652acf3..356d4ca788 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -446,59 +446,41 @@ class StatsStore(StateDeltasStore):
             absolutes: Absolute (set) fields
             additive_relatives: Fields that will be added onto if existing row present.
         """
-        if self.database_engine.can_native_upsert:
-            absolute_updates = [
-                "%(field)s = EXCLUDED.%(field)s" % {"field": field}
-                for field in absolutes.keys()
-            ]
-
-            relative_updates = [
-                "%(field)s = EXCLUDED.%(field)s + COALESCE(%(table)s.%(field)s, 0)"
-                % {"table": table, "field": field}
-                for field in additive_relatives.keys()
-            ]
-
-            insert_cols = []
-            qargs = []
-
-            for (key, val) in chain(
-                keyvalues.items(), absolutes.items(), additive_relatives.items()
-            ):
-                insert_cols.append(key)
-                qargs.append(val)
+        absolute_updates = [
+            "%(field)s = EXCLUDED.%(field)s" % {"field": field}
+            for field in absolutes.keys()
+        ]
+
+        relative_updates = [
+            "%(field)s = EXCLUDED.%(field)s + COALESCE(%(table)s.%(field)s, 0)"
+            % {"table": table, "field": field}
+            for field in additive_relatives.keys()
+        ]
+
+        insert_cols = []
+        qargs = []
+
+        for (key, val) in chain(
+            keyvalues.items(), absolutes.items(), additive_relatives.items()
+        ):
+            insert_cols.append(key)
+            qargs.append(val)
+
+        sql = """
+            INSERT INTO %(table)s (%(insert_cols_cs)s)
+            VALUES (%(insert_vals_qs)s)
+            ON CONFLICT (%(key_columns)s) DO UPDATE SET %(updates)s
+        """ % {
+            "table": table,
+            "insert_cols_cs": ", ".join(insert_cols),
+            "insert_vals_qs": ", ".join(
+                ["?"] * (len(keyvalues) + len(absolutes) + len(additive_relatives))
+            ),
+            "key_columns": ", ".join(keyvalues),
+            "updates": ", ".join(chain(absolute_updates, relative_updates)),
+        }
 
-            sql = """
-                INSERT INTO %(table)s (%(insert_cols_cs)s)
-                VALUES (%(insert_vals_qs)s)
-                ON CONFLICT (%(key_columns)s) DO UPDATE SET %(updates)s
-            """ % {
-                "table": table,
-                "insert_cols_cs": ", ".join(insert_cols),
-                "insert_vals_qs": ", ".join(
-                    ["?"] * (len(keyvalues) + len(absolutes) + len(additive_relatives))
-                ),
-                "key_columns": ", ".join(keyvalues),
-                "updates": ", ".join(chain(absolute_updates, relative_updates)),
-            }
-
-            txn.execute(sql, qargs)
-        else:
-            self.database_engine.lock_table(txn, table)
-            retcols = list(chain(absolutes.keys(), additive_relatives.keys()))
-            current_row = self.db_pool.simple_select_one_txn(
-                txn, table, keyvalues, retcols, allow_none=True
-            )
-            if current_row is None:
-                merged_dict = {**keyvalues, **absolutes, **additive_relatives}
-                self.db_pool.simple_insert_txn(txn, table, merged_dict)
-            else:
-                for (key, val) in additive_relatives.items():
-                    if current_row[key] is None:
-                        current_row[key] = val
-                    else:
-                        current_row[key] += val
-                current_row.update(absolutes)
-                self.db_pool.simple_update_one_txn(txn, table, keyvalues, current_row)
+        txn.execute(sql, qargs)
 
     async def _calculate_and_set_initial_state_for_room(self, room_id: str) -> None:
         """Calculate and insert an entry into room_stats_current.
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 3a1df7776c..3f9bfaeac5 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -58,6 +58,7 @@ from twisted.internet import defer
 from synapse.api.filtering import Filter
 from synapse.events import EventBase
 from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.logging.opentracing import trace
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
     DatabasePool,
@@ -71,6 +72,7 @@ from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import PersistedEventPosition, RoomStreamToken
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.util.cancellation import cancellable
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -596,6 +598,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return ret, key
 
+    @cancellable
     async def get_membership_changes_for_user(
         self,
         user_id: str,
@@ -1022,8 +1025,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         }
 
     async def get_all_new_events_stream(
-        self, from_id: int, current_id: int, limit: int
-    ) -> Tuple[int, List[EventBase]]:
+        self, from_id: int, current_id: int, limit: int, get_prev_content: bool = False
+    ) -> Tuple[int, List[EventBase], Dict[str, Optional[int]]]:
         """Get all new events
 
         Returns all events with from_id < stream_ordering <= current_id.
@@ -1032,19 +1035,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             from_id:  the stream_ordering of the last event we processed
             current_id:  the stream_ordering of the most recently processed event
             limit: the maximum number of events to return
+            get_prev_content: whether to fetch previous event content
 
         Returns:
-            A tuple of (next_id, events), where `next_id` is the next value to
-            pass as `from_id` (it will either be the stream_ordering of the
-            last returned event, or, if fewer than `limit` events were found,
-            the `current_id`).
+            A tuple of (next_id, events, event_to_received_ts), where `next_id`
+            is the next value to pass as `from_id` (it will either be the
+            stream_ordering of the last returned event, or, if fewer than `limit`
+            events were found, the `current_id`). The `event_to_received_ts` is
+            a dictionary mapping event ID to the event `received_ts`.
         """
 
         def get_all_new_events_stream_txn(
             txn: LoggingTransaction,
-        ) -> Tuple[int, List[str]]:
+        ) -> Tuple[int, Dict[str, Optional[int]]]:
             sql = (
-                "SELECT e.stream_ordering, e.event_id"
+                "SELECT e.stream_ordering, e.event_id, e.received_ts"
                 " FROM events AS e"
                 " WHERE"
                 " ? < e.stream_ordering AND e.stream_ordering <= ?"
@@ -1059,15 +1064,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             if len(rows) == limit:
                 upper_bound = rows[-1][0]
 
-            return upper_bound, [row[1] for row in rows]
+            event_to_received_ts: Dict[str, Optional[int]] = {
+                row[1]: row[2] for row in rows
+            }
+            return upper_bound, event_to_received_ts
 
-        upper_bound, event_ids = await self.db_pool.runInteraction(
+        upper_bound, event_to_received_ts = await self.db_pool.runInteraction(
             "get_all_new_events_stream", get_all_new_events_stream_txn
         )
 
-        events = await self.get_events_as_list(event_ids)
+        events = await self.get_events_as_list(
+            event_to_received_ts.keys(),
+            get_prev_content=get_prev_content,
+        )
 
-        return upper_bound, events
+        return upper_bound, events, event_to_received_ts
 
     async def get_federation_out_pos(self, typ: str) -> int:
         if self._need_to_reset_federation_stream_positions:
@@ -1338,6 +1349,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return rows, next_token
 
+    @trace
     async def paginate_room_events(
         self,
         room_id: str,
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index ba79e19f7f..f8c6877ee8 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -221,25 +221,15 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
             retry_interval: how long until next retry in ms
         """
 
-        if self.database_engine.can_native_upsert:
-            await self.db_pool.runInteraction(
-                "set_destination_retry_timings",
-                self._set_destination_retry_timings_native,
-                destination,
-                failure_ts,
-                retry_last_ts,
-                retry_interval,
-                db_autocommit=True,  # Safe as its a single upsert
-            )
-        else:
-            await self.db_pool.runInteraction(
-                "set_destination_retry_timings",
-                self._set_destination_retry_timings_emulated,
-                destination,
-                failure_ts,
-                retry_last_ts,
-                retry_interval,
-            )
+        await self.db_pool.runInteraction(
+            "set_destination_retry_timings",
+            self._set_destination_retry_timings_native,
+            destination,
+            failure_ts,
+            retry_last_ts,
+            retry_interval,
+            db_autocommit=True,  # Safe as it's a single upsert
+        )
 
     def _set_destination_retry_timings_native(
         self,
@@ -249,8 +239,6 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
         retry_last_ts: int,
         retry_interval: int,
     ) -> None:
-        assert self.database_engine.can_native_upsert
-
         # Upsert retry time interval if retry_interval is zero (i.e. we're
         # resetting it) or greater than the existing retry interval.
         #
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index fa9eadaca7..a7fcc564a9 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -24,6 +24,7 @@ from synapse.storage.database import (
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.state import StateFilter
 from synapse.types import MutableStateMap, StateMap
+from synapse.util.caches import intern_string
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -136,7 +137,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
                 txn.execute(sql % (where_clause,), args)
                 for row in txn:
                     typ, state_key, event_id = row
-                    key = (typ, state_key)
+                    key = (intern_string(typ), intern_string(state_key))
                     results[group][key] = event_id
         else:
             max_entries_returned = state_filter.max_entries_returned()
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 609a2b88bf..f8cfcaca83 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -31,6 +31,7 @@ from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import MutableStateMap, StateKey, StateMap
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.dictionary_cache import DictionaryCache
+from synapse.util.cancellation import cancellable
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -156,6 +157,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             "get_state_group_delta", _get_state_group_delta_txn
         )
 
+    @cancellable
     async def _get_state_groups_from_groups(
         self, groups: List[int], state_filter: StateFilter
     ) -> Dict[int, StateMap[str]]:
@@ -202,7 +204,14 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 requests state from the cache, if False we need to query the DB for the
                 missing state.
         """
-        cache_entry = cache.get(group)
+        # If we are asked explicitly for a subset of keys, we only ask for those
+        # from the cache. This ensures that the `DictionaryCache` can make
+        # better decisions about what to cache and what to expire.
+        dict_keys = None
+        if not state_filter.has_wildcards():
+            dict_keys = state_filter.concrete_types()
+
+        cache_entry = cache.get(group, dict_keys=dict_keys)
         state_dict_ids = cache_entry.value
 
         if cache_entry.full or state_filter.is_full():
@@ -228,6 +237,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
 
         return state_filter.filter_state(state_dict_ids), not missing_types
 
+    @cancellable
     async def _get_state_for_groups(
         self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
     ) -> Dict[int, MutableStateMap[str]]:
@@ -400,14 +410,17 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         room_id: str,
         prev_group: Optional[int],
         delta_ids: Optional[StateMap[str]],
-        current_state_ids: StateMap[str],
+        current_state_ids: Optional[StateMap[str]],
     ) -> int:
         """Store a new set of state, returning a newly assigned state group.
 
+        At least one of `current_state_ids` and `prev_group` must be provided. Whenever
+        `prev_group` is not None, `delta_ids` must also not be None.
+
         Args:
             event_id: The event ID for which the state was calculated
             room_id
-            prev_group: A previous state group for the room, optional.
+            prev_group: A previous state group for the room.
             delta_ids: The delta between state at `prev_group` and
                 `current_state_ids`, if `prev_group` was given. Same format as
                 `current_state_ids`.
@@ -418,10 +431,41 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             The state group ID
         """
 
-        def _store_state_group_txn(txn: LoggingTransaction) -> int:
-            if current_state_ids is None:
-                # AFAIK, this can never happen
-                raise Exception("current_state_ids cannot be None")
+        if prev_group is None and current_state_ids is None:
+            raise Exception("current_state_ids and prev_group can't both be None")
+
+        if prev_group is not None and delta_ids is None:
+            raise Exception("delta_ids is None when prev_group is not None")
+
+        def insert_delta_group_txn(
+            txn: LoggingTransaction, prev_group: int, delta_ids: StateMap[str]
+        ) -> Optional[int]:
+            """Try and persist the new group as a delta.
+
+            Requires that we have the state as a delta from a previous state group.
+
+            Returns:
+                The state group if successfully created, or None if the state
+                needs to be persisted as a full state.
+            """
+            is_in_db = self.db_pool.simple_select_one_onecol_txn(
+                txn,
+                table="state_groups",
+                keyvalues={"id": prev_group},
+                retcol="id",
+                allow_none=True,
+            )
+            if not is_in_db:
+                raise Exception(
+                    "Trying to persist state with unpersisted prev_group: %r"
+                    % (prev_group,)
+                )
+
+            # if the chain of state group deltas is going too long, we fall back to
+            # persisting a complete state group.
+            potential_hops = self._count_state_group_hops_txn(txn, prev_group)
+            if potential_hops >= MAX_STATE_DELTA_HOPS:
+                return None
 
             state_group = self._state_group_seq_gen.get_next_id_txn(txn)
 
@@ -431,51 +475,45 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 values={"id": state_group, "room_id": room_id, "event_id": event_id},
             )
 
-            # We persist as a delta if we can, while also ensuring the chain
-            # of deltas isn't tooo long, as otherwise read performance degrades.
-            if prev_group:
-                is_in_db = self.db_pool.simple_select_one_onecol_txn(
-                    txn,
-                    table="state_groups",
-                    keyvalues={"id": prev_group},
-                    retcol="id",
-                    allow_none=True,
-                )
-                if not is_in_db:
-                    raise Exception(
-                        "Trying to persist state with unpersisted prev_group: %r"
-                        % (prev_group,)
-                    )
-
-                potential_hops = self._count_state_group_hops_txn(txn, prev_group)
-            if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
-                assert delta_ids is not None
-
-                self.db_pool.simple_insert_txn(
-                    txn,
-                    table="state_group_edges",
-                    values={"state_group": state_group, "prev_state_group": prev_group},
-                )
+            self.db_pool.simple_insert_txn(
+                txn,
+                table="state_group_edges",
+                values={"state_group": state_group, "prev_state_group": prev_group},
+            )
 
-                self.db_pool.simple_insert_many_txn(
-                    txn,
-                    table="state_groups_state",
-                    keys=("state_group", "room_id", "type", "state_key", "event_id"),
-                    values=[
-                        (state_group, room_id, key[0], key[1], state_id)
-                        for key, state_id in delta_ids.items()
-                    ],
-                )
-            else:
-                self.db_pool.simple_insert_many_txn(
-                    txn,
-                    table="state_groups_state",
-                    keys=("state_group", "room_id", "type", "state_key", "event_id"),
-                    values=[
-                        (state_group, room_id, key[0], key[1], state_id)
-                        for key, state_id in current_state_ids.items()
-                    ],
-                )
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="state_groups_state",
+                keys=("state_group", "room_id", "type", "state_key", "event_id"),
+                values=[
+                    (state_group, room_id, key[0], key[1], state_id)
+                    for key, state_id in delta_ids.items()
+                ],
+            )
+
+            return state_group
+
+        def insert_full_state_txn(
+            txn: LoggingTransaction, current_state_ids: StateMap[str]
+        ) -> int:
+            """Persist the full state, returning the new state group."""
+            state_group = self._state_group_seq_gen.get_next_id_txn(txn)
+
+            self.db_pool.simple_insert_txn(
+                txn,
+                table="state_groups",
+                values={"id": state_group, "room_id": room_id, "event_id": event_id},
+            )
+
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="state_groups_state",
+                keys=("state_group", "room_id", "type", "state_key", "event_id"),
+                values=[
+                    (state_group, room_id, key[0], key[1], state_id)
+                    for key, state_id in current_state_ids.items()
+                ],
+            )
 
             # Prefill the state group caches with this group.
             # It's fine to use the sequence like this as the state group map
@@ -491,7 +529,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 self._state_group_members_cache.update,
                 self._state_group_members_cache.sequence,
                 key=state_group,
-                value=dict(current_member_state_ids),
+                value=current_member_state_ids,
             )
 
             current_non_member_state_ids = {
@@ -503,13 +541,35 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 self._state_group_cache.update,
                 self._state_group_cache.sequence,
                 key=state_group,
-                value=dict(current_non_member_state_ids),
+                value=current_non_member_state_ids,
             )
 
             return state_group
 
+        if prev_group is not None:
+            state_group = await self.db_pool.runInteraction(
+                "store_state_group.insert_delta_group",
+                insert_delta_group_txn,
+                prev_group,
+                delta_ids,
+            )
+            if state_group is not None:
+                return state_group
+
+        # We're going to persist the state as a complete group rather than
+        # a delta, so first we need to ensure we have loaded the state map
+        # from the database.
+        if current_state_ids is None:
+            assert prev_group is not None
+            assert delta_ids is not None
+            groups = await self._get_state_for_groups([prev_group])
+            current_state_ids = dict(groups[prev_group])
+            current_state_ids.update(delta_ids)
+
         return await self.db_pool.runInteraction(
-            "store_state_group", _store_state_group_txn
+            "store_state_group.insert_full_state",
+            insert_full_state_txn,
+            current_state_ids,
         )
 
     async def purge_unreferenced_state_groups(
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index 971ff82693..0d16a419a4 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -45,14 +45,6 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
 
     @property
     @abc.abstractmethod
-    def can_native_upsert(self) -> bool:
-        """
-        Do we support native UPSERTs?
-        """
-        ...
-
-    @property
-    @abc.abstractmethod
     def supports_using_any_list(self) -> bool:
         """
         Do we support using `a = ANY(?)` and passing a list
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 517f9d5f98..7f7d006ac2 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -159,13 +159,6 @@ class PostgresEngine(BaseDatabaseEngine[psycopg2.extensions.connection]):
         db_conn.commit()
 
     @property
-    def can_native_upsert(self) -> bool:
-        """
-        Can we use native UPSERTs?
-        """
-        return True
-
-    @property
     def supports_using_any_list(self) -> bool:
         """Do we support using `a = ANY(?)` and passing a list"""
         return True
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 621f2c5efe..095ae0a096 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -49,14 +49,6 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection]):
         return True
 
     @property
-    def can_native_upsert(self) -> bool:
-        """
-        Do we support native UPSERTs? This requires SQLite3 3.24+, plus some
-        more work we haven't done yet to tell what was inserted vs updated.
-        """
-        return sqlite3.sqlite_version_info >= (3, 24, 0)
-
-    @property
     def supports_using_any_list(self) -> bool:
         """Do we support using `a = ANY(?)` and passing a list"""
         return False
@@ -70,12 +62,11 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection]):
         self, db_conn: sqlite3.Connection, allow_outdated_version: bool = False
     ) -> None:
         if not allow_outdated_version:
-            version = sqlite3.sqlite_version_info
             # Synapse is untested against older SQLite versions, and we don't want
             # to let users upgrade to a version of Synapse with broken support for their
             # sqlite version, because it risks leaving them with a half-upgraded db.
-            if version < (3, 22, 0):
-                raise RuntimeError("Synapse requires sqlite 3.22 or above.")
+            if sqlite3.sqlite_version_info < (3, 27, 0):
+                raise RuntimeError("Synapse requires sqlite 3.27 or above.")
 
     def check_new_database(self, txn: Cursor) -> None:
         """Gets called when setting up a brand new database. This allows us to
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index dc237e3032..68e055c664 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-SCHEMA_VERSION = 72  # remember to update the list below when updating
+SCHEMA_VERSION = 73  # remember to update the list below when updating
 """Represents the expectations made by the codebase about the database schema
 
 This should be incremented whenever the codebase changes its requirements on the
@@ -74,13 +74,22 @@ Changes in SCHEMA_VERSION = 71:
 
 Changes in SCHEMA_VERSION = 72:
     - event_edges.(room_id, is_state) are no longer written to.
+    - Tables related to groups are dropped.
+    - Unused column application_services_state.last_txn is dropped
+    - Cache invalidation stream id sequence now begins at 2 to match code expectation.
+
+Changes in SCHEMA_VERSION = 73;
+    - thread_id column is added to event_push_actions, event_push_actions_staging
+      event_push_summary, receipts_linearized, and receipts_graph.
+    - Add table `event_failed_pull_attempts` to keep track when we fail to pull
+      events over federation.
 """
 
 
 SCHEMA_COMPAT_VERSION = (
-    # We no longer maintain `event_edges.room_id`, so synapses with SCHEMA_VERSION < 71
-    # will break.
-    71
+    # The groups tables are no longer accessible, so synapses with SCHEMA_VERSION < 72
+    # could break.
+    72
 )
 """Limit on how far the synapse codebase can be rolled back without breaking db compat
 
diff --git a/synapse/storage/schema/main/delta/72/03bg_populate_events_columns.py b/synapse/storage/schema/main/delta/72/03bg_populate_events_columns.py
new file mode 100644
index 0000000000..55a5d092cc
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/03bg_populate_events_columns.py
@@ -0,0 +1,47 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+from synapse.storage.types import Cursor
+
+
+def run_create(cur: Cursor, database_engine, *args, **kwargs):
+    """Add a bg update to populate the `state_key` and `rejection_reason` columns of `events`"""
+
+    # we know that any new events will have the columns populated (and that has been
+    # the case since schema_version 68, so there is no chance of rolling back now).
+    #
+    # So, we only need to make sure that existing rows are updated. We read the
+    # current min and max stream orderings, since that is guaranteed to include all
+    # the events that were stored before the new columns were added.
+    cur.execute("SELECT MIN(stream_ordering), MAX(stream_ordering) FROM events")
+    (min_stream_ordering, max_stream_ordering) = cur.fetchone()
+
+    if min_stream_ordering is None:
+        # no rows, nothing to do.
+        return
+
+    cur.execute(
+        "INSERT into background_updates (ordering, update_name, progress_json)"
+        " VALUES (7203, 'events_populate_state_key_rejections', ?)",
+        (
+            json.dumps(
+                {
+                    "min_stream_ordering_exclusive": min_stream_ordering - 1,
+                    "max_stream_ordering_inclusive": max_stream_ordering,
+                }
+            ),
+        ),
+    )
diff --git a/synapse/storage/schema/main/delta/72/03drop_event_reference_hashes.sql b/synapse/storage/schema/main/delta/72/03drop_event_reference_hashes.sql
new file mode 100644
index 0000000000..0da668aa3a
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/03drop_event_reference_hashes.sql
@@ -0,0 +1,17 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- event_reference_hashes is unused, so we can drop it
+DROP TABLE event_reference_hashes;
diff --git a/synapse/storage/schema/main/delta/72/03remove_groups.sql b/synapse/storage/schema/main/delta/72/03remove_groups.sql
new file mode 100644
index 0000000000..b7c5894de8
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/03remove_groups.sql
@@ -0,0 +1,31 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Remove the tables which powered the unspecced groups/communities feature.
+DROP TABLE IF EXISTS group_attestations_remote;
+DROP TABLE IF EXISTS group_attestations_renewals;
+DROP TABLE IF EXISTS group_invites;
+DROP TABLE IF EXISTS group_roles;
+DROP TABLE IF EXISTS group_room_categories;
+DROP TABLE IF EXISTS group_rooms;
+DROP TABLE IF EXISTS group_summary_roles;
+DROP TABLE IF EXISTS group_summary_room_categories;
+DROP TABLE IF EXISTS group_summary_rooms;
+DROP TABLE IF EXISTS group_summary_users;
+DROP TABLE IF EXISTS group_users;
+DROP TABLE IF EXISTS groups;
+DROP TABLE IF EXISTS local_group_membership;
+DROP TABLE IF EXISTS local_group_updates;
+DROP TABLE IF EXISTS remote_profile_cache;
diff --git a/synapse/storage/schema/main/delta/72/04drop_column_application_services_state_last_txn.sql.postgres b/synapse/storage/schema/main/delta/72/04drop_column_application_services_state_last_txn.sql.postgres
new file mode 100644
index 0000000000..13d47de9e6
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/04drop_column_application_services_state_last_txn.sql.postgres
@@ -0,0 +1,17 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Drop unused column application_services_state.last_txn
+ALTER table application_services_state DROP COLUMN last_txn;
\ No newline at end of file
diff --git a/synapse/storage/schema/main/delta/72/04drop_column_application_services_state_last_txn.sql.sqlite b/synapse/storage/schema/main/delta/72/04drop_column_application_services_state_last_txn.sql.sqlite
new file mode 100644
index 0000000000..3be1c88d72
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/04drop_column_application_services_state_last_txn.sql.sqlite
@@ -0,0 +1,40 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Drop unused column application_services_state.last_txn
+
+CREATE TABLE application_services_state2 (
+    as_id TEXT PRIMARY KEY NOT NULL,
+    state VARCHAR(5),
+    read_receipt_stream_id BIGINT,
+    presence_stream_id BIGINT,
+    to_device_stream_id BIGINT,
+    device_list_stream_id BIGINT
+);
+
+
+INSERT INTO application_services_state2 (
+    as_id,
+    state,
+    read_receipt_stream_id,
+    presence_stream_id,
+    to_device_stream_id,
+    device_list_stream_id
+)
+SELECT as_id, state, read_receipt_stream_id, presence_stream_id, to_device_stream_id,  device_list_stream_id
+FROM application_services_state;
+
+DROP TABLE application_services_state;
+ALTER TABLE application_services_state2 RENAME TO application_services_state;
diff --git a/synapse/storage/schema/main/delta/72/05receipts_event_stream_ordering.sql b/synapse/storage/schema/main/delta/72/05receipts_event_stream_ordering.sql
new file mode 100644
index 0000000000..2a822f4509
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/05receipts_event_stream_ordering.sql
@@ -0,0 +1,19 @@
+/* Copyright 2022 Beeper
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE receipts_linearized ADD COLUMN event_stream_ordering BIGINT;
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+  ('populate_event_stream_ordering', '{}');
diff --git a/synapse/storage/schema/main/delta/72/05remove_unstable_private_read_receipts.sql b/synapse/storage/schema/main/delta/72/05remove_unstable_private_read_receipts.sql
new file mode 100644
index 0000000000..36b41049cd
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/05remove_unstable_private_read_receipts.sql
@@ -0,0 +1,19 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Drop previously received private read receipts so they do not accidentally
+-- get leaked to other users.
+DELETE FROM receipts_linearized WHERE receipt_type = 'org.matrix.msc2285.read.private';
+DELETE FROM receipts_graph WHERE receipt_type = 'org.matrix.msc2285.read.private';
diff --git a/synapse/storage/schema/main/delta/72/06add_consent_ts_to_users.sql b/synapse/storage/schema/main/delta/72/06add_consent_ts_to_users.sql
new file mode 100644
index 0000000000..609eb1750f
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/06add_consent_ts_to_users.sql
@@ -0,0 +1,16 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE users ADD consent_ts bigint;
diff --git a/synapse/storage/schema/main/delta/72/06thread_notifications.sql b/synapse/storage/schema/main/delta/72/06thread_notifications.sql
new file mode 100644
index 0000000000..2f4f5dac7a
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/06thread_notifications.sql
@@ -0,0 +1,30 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Add a nullable column for thread ID to the event push actions tables; this
+-- will be filled in with a default value for any previously existing rows.
+--
+-- After migration this can be made non-nullable.
+
+ALTER TABLE event_push_actions_staging ADD COLUMN thread_id TEXT;
+ALTER TABLE event_push_actions ADD COLUMN thread_id TEXT;
+ALTER TABLE event_push_summary ADD COLUMN thread_id TEXT;
+
+-- Update the unique index for `event_push_summary`.
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (7006, 'event_push_summary_unique_index2', '{}');
+
+INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
+  (7006, 'event_push_backfill_thread_id', '{}', 'event_push_summary_unique_index2');
diff --git a/synapse/storage/schema/main/delta/72/07force_update_current_state_events_membership.py b/synapse/storage/schema/main/delta/72/07force_update_current_state_events_membership.py
new file mode 100644
index 0000000000..b5853d125c
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/07force_update_current_state_events_membership.py
@@ -0,0 +1,52 @@
+# Copyright 2022 Beeper
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+Forces through the `current_state_events_membership` background job so checks
+for its completion can be removed.
+
+Note the background job must still remain defined in the database class.
+"""
+
+
+def run_upgrade(cur, database_engine, *args, **kwargs):
+    cur.execute("SELECT update_name FROM background_updates")
+    rows = cur.fetchall()
+    for row in rows:
+        if row[0] == "current_state_events_membership":
+            break
+    # No pending background job so nothing to do here
+    else:
+        return
+
+    # Populate membership field for all current_state_events, this may take
+    # a while but was originally handled via a background update in 2019.
+    cur.execute(
+        """
+        UPDATE current_state_events
+        SET membership = (
+            SELECT membership FROM room_memberships
+            WHERE event_id = current_state_events.event_id
+        )
+        """
+    )
+
+    # Finally, delete the background job because we've handled it above
+    cur.execute(
+        """
+        DELETE FROM background_updates
+        WHERE update_name = 'current_state_events_membership'
+        """
+    )
diff --git a/synapse/storage/schema/main/delta/72/07thread_receipts.sql.postgres b/synapse/storage/schema/main/delta/72/07thread_receipts.sql.postgres
new file mode 100644
index 0000000000..55fff9e278
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/07thread_receipts.sql.postgres
@@ -0,0 +1,30 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Add a nullable column for thread ID to the receipts table; this allows a
+-- receipt per user, per room, as well as an unthreaded receipt (corresponding
+-- to a null thread ID).
+
+ALTER TABLE receipts_linearized ADD COLUMN thread_id TEXT;
+ALTER TABLE receipts_graph ADD COLUMN thread_id TEXT;
+
+-- Rebuild the unique constraint with the thread_id.
+ALTER TABLE receipts_linearized
+    ADD CONSTRAINT receipts_linearized_uniqueness_thread
+        UNIQUE (room_id, receipt_type, user_id, thread_id);
+
+ALTER TABLE receipts_graph
+    ADD CONSTRAINT receipts_graph_uniqueness_thread
+        UNIQUE (room_id, receipt_type, user_id, thread_id);
diff --git a/synapse/storage/schema/main/delta/72/07thread_receipts.sql.sqlite b/synapse/storage/schema/main/delta/72/07thread_receipts.sql.sqlite
new file mode 100644
index 0000000000..232f67deb4
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/07thread_receipts.sql.sqlite
@@ -0,0 +1,70 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Allow multiple receipts per user per room via a nullable thread_id column.
+--
+-- SQLite doesn't support modifying constraints to an existing table, so it must
+-- be recreated.
+
+-- Create the new tables.
+CREATE TABLE receipts_linearized_new (
+    stream_id BIGINT NOT NULL,
+    room_id TEXT NOT NULL,
+    receipt_type TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    event_id TEXT NOT NULL,
+    thread_id TEXT,
+    event_stream_ordering BIGINT,
+    data TEXT NOT NULL,
+    CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id),
+    CONSTRAINT receipts_linearized_uniqueness_thread UNIQUE (room_id, receipt_type, user_id, thread_id)
+);
+
+CREATE TABLE receipts_graph_new (
+    room_id TEXT NOT NULL,
+    receipt_type TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    event_ids TEXT NOT NULL,
+    thread_id TEXT,
+    data TEXT NOT NULL,
+    CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id),
+    CONSTRAINT receipts_graph_uniqueness_thread UNIQUE (room_id, receipt_type, user_id, thread_id)
+);
+
+-- Drop the old indexes.
+DROP INDEX IF EXISTS receipts_linearized_id;
+DROP INDEX IF EXISTS receipts_linearized_room_stream;
+DROP INDEX IF EXISTS receipts_linearized_user;
+
+-- Copy the data.
+INSERT INTO receipts_linearized_new (stream_id, room_id, receipt_type, user_id, event_id, event_stream_ordering, data)
+    SELECT stream_id, room_id, receipt_type, user_id, event_id, event_stream_ordering, data
+    FROM receipts_linearized;
+INSERT INTO receipts_graph_new (room_id, receipt_type, user_id, event_ids, data)
+    SELECT room_id, receipt_type, user_id, event_ids, data
+    FROM receipts_graph;
+
+-- Drop the old tables.
+DROP TABLE receipts_linearized;
+DROP TABLE receipts_graph;
+
+-- Rename the tables.
+ALTER TABLE receipts_linearized_new RENAME TO receipts_linearized;
+ALTER TABLE receipts_graph_new RENAME TO receipts_graph;
+
+-- Create the indices.
+CREATE INDEX receipts_linearized_id ON receipts_linearized( stream_id );
+CREATE INDEX receipts_linearized_room_stream ON receipts_linearized( room_id, stream_id );
+CREATE INDEX receipts_linearized_user ON receipts_linearized( user_id );
diff --git a/synapse/storage/schema/main/delta/72/08begin_cache_invalidation_seq_at_2.sql.postgres b/synapse/storage/schema/main/delta/72/08begin_cache_invalidation_seq_at_2.sql.postgres
new file mode 100644
index 0000000000..69931fe971
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/08begin_cache_invalidation_seq_at_2.sql.postgres
@@ -0,0 +1,23 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+ 
+
+-- The sequence needs to begin at 2 because a bunch of code assumes that
+-- get_next_id_txn will return values >= 2, cf this comment:
+-- https://github.com/matrix-org/synapse/blob/b93bd95e8ab64d27ae26841020f62ee61272a5f2/synapse/storage/util/id_generators.py#L344
+
+SELECT setval('cache_invalidation_stream_seq', (
+    SELECT COALESCE(MAX(last_value), 1) FROM cache_invalidation_stream_seq
+));
diff --git a/synapse/storage/schema/main/delta/72/08thread_receipts.sql b/synapse/storage/schema/main/delta/72/08thread_receipts.sql
new file mode 100644
index 0000000000..e35b021f31
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/08thread_receipts.sql
@@ -0,0 +1,20 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (7007, 'receipts_linearized_unique_index', '{}');
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (7007, 'receipts_graph_unique_index', '{}');
diff --git a/synapse/storage/schema/main/delta/72/09partial_indices.sql.sqlite b/synapse/storage/schema/main/delta/72/09partial_indices.sql.sqlite
new file mode 100644
index 0000000000..c8dfdf0218
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/09partial_indices.sql.sqlite
@@ -0,0 +1,56 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+ 
+-- SQLite needs to rebuild indices which use partial indices on Postgres, but
+-- previously did not use them on SQLite.
+
+-- Drop each index that was added with register_background_index_update AND specified
+-- a where_clause (that existed before this delta).
+
+-- From events_bg_updates.py
+DROP INDEX IF EXISTS event_contains_url_index;
+-- There is also a redactions_censored_redacts index, but that gets dropped.
+DROP INDEX IF EXISTS redactions_have_censored_ts;
+-- There is also a PostgreSQL only index (event_contains_url_index2)
+-- which gets renamed to event_contains_url_index.
+
+-- From roommember.py
+DROP INDEX IF EXISTS room_memberships_user_room_forgotten;
+
+-- From presence.py
+DROP INDEX IF EXISTS presence_stream_state_not_offline_idx;
+
+-- From media_repository.py
+DROP INDEX IF EXISTS local_media_repository_url_idx;
+
+-- From event_push_actions.py
+DROP INDEX IF EXISTS event_push_actions_highlights_index;
+-- There's also a event_push_actions_stream_highlight_index which was previously
+-- PostgreSQL-only.
+
+-- From state.py
+DROP INDEX IF EXISTS current_state_events_member_index;
+
+-- Re-insert the background jobs to re-create the indices.
+INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
+  (7209, 'event_contains_url_index', '{}', NULL),
+  (7209, 'redactions_have_censored_ts_idx', '{}', NULL),
+  (7209, 'room_membership_forgotten_idx', '{}', NULL),
+  (7209, 'presence_stream_not_offline_index', '{}', NULL),
+  (7209, 'local_media_repository_url_idx', '{}', NULL),
+  (7209, 'event_push_actions_highlights_index', '{}', NULL),
+  (7209, 'event_push_actions_stream_highlight_index', '{}', NULL),
+  (7209, 'current_state_members_idx', '{}', NULL)
+ON CONFLICT (update_name) DO NOTHING;
diff --git a/synapse/storage/schema/main/delta/73/01event_failed_pull_attempts.sql b/synapse/storage/schema/main/delta/73/01event_failed_pull_attempts.sql
new file mode 100644
index 0000000000..d397ee1082
--- /dev/null
+++ b/synapse/storage/schema/main/delta/73/01event_failed_pull_attempts.sql
@@ -0,0 +1,29 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- Add a table that keeps track of when we failed to pull an event over
+-- federation (via /backfill, `/event`, `/get_missing_events`, etc). This allows
+-- us to be more intelligent when we decide to retry (we don't need to fail over
+-- and over) and we can process that event in the background so we don't block
+-- on it each time.
+CREATE TABLE IF NOT EXISTS event_failed_pull_attempts(
+    room_id TEXT NOT NULL REFERENCES rooms (room_id),
+    event_id TEXT NOT NULL,
+    num_attempts INT NOT NULL,
+    last_attempt_ts BIGINT NOT NULL,
+    last_cause TEXT NOT NULL,
+    PRIMARY KEY (room_id, event_id)
+);
diff --git a/synapse/storage/schema/state/delta/30/state_stream.sql b/synapse/storage/schema/state/delta/30/state_stream.sql
deleted file mode 100644
index e85699e82e..0000000000
--- a/synapse/storage/schema/state/delta/30/state_stream.sql
+++ /dev/null
@@ -1,33 +0,0 @@
-/* Copyright 2016 OpenMarket Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-
-/* We used to create a table called current_state_resets, but this is no
- * longer used and is removed in delta 54.
- */
-
-/* The outlier events that have aquired a state group typically through
- * backfill. This is tracked separately to the events table, as assigning a
- * state group change the position of the existing event in the stream
- * ordering.
- * However since a stream_ordering is assigned in persist_event for the
- * (event, state) pair, we can use that stream_ordering to identify when
- * the new state was assigned for the event.
- */
-CREATE TABLE IF NOT EXISTS ex_outlier_stream(
-    event_stream_ordering BIGINT PRIMARY KEY NOT NULL,
-    event_id TEXT NOT NULL,
-    state_group BIGINT NOT NULL
-);
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index af3bab2c15..0004d955b4 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -539,15 +539,6 @@ class StateFilter:
             is_mine_id: a callable which confirms if a given state_key matches a mxid
                of a local user
         """
-
-        # TODO(faster_joins): it's not entirely clear that this is safe. In particular,
-        #  there may be circumstances in which we return a piece of state that, once we
-        #  resync the state, we discover is invalid. For example: if it turns out that
-        #  the sender of a piece of state wasn't actually in the room, then clearly that
-        #  state shouldn't have been returned.
-        #  We should at least add some tests around this to see what happens.
-        #  https://github.com/matrix-org/synapse/issues/13006
-
         # if we haven't requested membership events, then it depends on the value of
         # 'include_others'
         if EventTypes.Member not in self.types:
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 3c13859faa..2dfe4c0b66 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -460,8 +460,17 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
                 # Cast safety: this corresponds to the types returned by the query above.
                 rows.extend(cast(Iterable[Tuple[str, int]], cur))
 
-            # Sort so that we handle rows in order for each instance.
-            rows.sort()
+            # Sort by stream_id (ascending, lowest -> highest) so that we handle
+            # rows in order for each instance because we don't want to overwrite
+            # the current_position of an instance to a lower stream ID than
+            # we're actually at.
+            def sort_by_stream_id_key_func(row: Tuple[str, int]) -> int:
+                (instance, stream_id) = row
+                # If `stream_id` is ever `None`, we will see a `TypeError: '<'
+                # not supported between instances of 'NoneType' and 'X'` error.
+                return stream_id
+
+            rows.sort(key=sort_by_stream_id_key_func)
 
             with self._lock:
                 for (
diff --git a/synapse/storage/util/partial_state_events_tracker.py b/synapse/storage/util/partial_state_events_tracker.py
index 211437cfaa..8d8894d1d5 100644
--- a/synapse/storage/util/partial_state_events_tracker.py
+++ b/synapse/storage/util/partial_state_events_tracker.py
@@ -20,9 +20,11 @@ from twisted.internet import defer
 from twisted.internet.defer import Deferred
 
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.logging.opentracing import trace_with_opname
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.room import RoomWorkerStore
 from synapse.util import unwrapFirstError
+from synapse.util.cancellation import cancellable
 
 logger = logging.getLogger(__name__)
 
@@ -58,6 +60,8 @@ class PartialStateEventsTracker:
             for o in observers:
                 o.callback(None)
 
+    @trace_with_opname("PartialStateEventsTracker.await_full_state")
+    @cancellable
     async def await_full_state(self, event_ids: Collection[str]) -> None:
         """Wait for all the given events to have full state.
 
@@ -151,6 +155,8 @@ class PartialCurrentStateTracker:
             for o in observers:
                 o.callback(None)
 
+    @trace_with_opname("PartialCurrentStateTracker.await_full_state")
+    @cancellable
     async def await_full_state(self, room_id: str) -> None:
         # We add the deferred immediately so that the DB call to check for
         # partial state doesn't race when we unpartial the room.
@@ -166,6 +172,7 @@ class PartialCurrentStateTracker:
             logger.info(
                 "Awaiting un-partial-stating of room %s",
                 room_id,
+                stack_info=True,
             )
 
             await make_deferred_yieldable(d)
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 54e0b1a23b..bcd840bd88 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -21,6 +21,7 @@ from synapse.handlers.presence import PresenceEventSource
 from synapse.handlers.receipts import ReceiptEventSource
 from synapse.handlers.room import RoomEventSource
 from synapse.handlers.typing import TypingNotificationEventSource
+from synapse.logging.opentracing import trace
 from synapse.streams import EventSource
 from synapse.types import StreamToken
 
@@ -69,6 +70,7 @@ class EventSources:
         )
         return token
 
+    @trace
     async def get_current_token_for_pagination(self, room_id: str) -> StreamToken:
         """Get the current token for a given room to be used to paginate
         events.
diff --git a/synapse/types.py b/synapse/types.py
index 668d48d646..ec44601f54 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -52,6 +52,7 @@ from twisted.internet.interfaces import (
 )
 
 from synapse.api.errors import Codes, SynapseError
+from synapse.util.cancellation import cancellable
 from synapse.util.stringutils import parse_and_validate_server_name
 
 if TYPE_CHECKING:
@@ -699,7 +700,11 @@ class StreamToken:
     START: ClassVar["StreamToken"]
 
     @classmethod
+    @cancellable
     async def from_string(cls, store: "DataStore", string: str) -> "StreamToken":
+        """
+        Creates a RoomStreamToken from its textual representation.
+        """
         try:
             keys = string.split(cls._SEPARATOR)
             while len(keys) < len(attr.fields(cls)):
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 42f6abb5e1..f7c3a6794e 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -20,9 +20,11 @@ from sys import intern
 from typing import Any, Callable, Dict, List, Optional, Sized, TypeVar
 
 import attr
+from prometheus_client import REGISTRY
 from prometheus_client.core import Gauge
 
 from synapse.config.cache import add_resizable_cache
+from synapse.util.metrics import DynamicCollectorRegistry
 
 logger = logging.getLogger(__name__)
 
@@ -30,27 +32,62 @@ logger = logging.getLogger(__name__)
 # Whether to track estimated memory usage of the LruCaches.
 TRACK_MEMORY_USAGE = False
 
+# We track cache metrics in a special registry that lets us update the metrics
+# just before they are returned from the scrape endpoint.
+CACHE_METRIC_REGISTRY = DynamicCollectorRegistry()
 
 caches_by_name: Dict[str, Sized] = {}
-collectors_by_name: Dict[str, "CacheMetric"] = {}
 
-cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
-cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])
-cache_evicted = Gauge("synapse_util_caches_cache:evicted_size", "", ["name", "reason"])
-cache_total = Gauge("synapse_util_caches_cache:total", "", ["name"])
-cache_max_size = Gauge("synapse_util_caches_cache_max_size", "", ["name"])
+cache_size = Gauge(
+    "synapse_util_caches_cache_size", "", ["name"], registry=CACHE_METRIC_REGISTRY
+)
+cache_hits = Gauge(
+    "synapse_util_caches_cache_hits", "", ["name"], registry=CACHE_METRIC_REGISTRY
+)
+cache_evicted = Gauge(
+    "synapse_util_caches_cache_evicted_size",
+    "",
+    ["name", "reason"],
+    registry=CACHE_METRIC_REGISTRY,
+)
+cache_total = Gauge(
+    "synapse_util_caches_cache", "", ["name"], registry=CACHE_METRIC_REGISTRY
+)
+cache_max_size = Gauge(
+    "synapse_util_caches_cache_max_size", "", ["name"], registry=CACHE_METRIC_REGISTRY
+)
 cache_memory_usage = Gauge(
     "synapse_util_caches_cache_size_bytes",
     "Estimated memory usage of the caches",
     ["name"],
+    registry=CACHE_METRIC_REGISTRY,
 )
 
-response_cache_size = Gauge("synapse_util_caches_response_cache:size", "", ["name"])
-response_cache_hits = Gauge("synapse_util_caches_response_cache:hits", "", ["name"])
+response_cache_size = Gauge(
+    "synapse_util_caches_response_cache_size",
+    "",
+    ["name"],
+    registry=CACHE_METRIC_REGISTRY,
+)
+response_cache_hits = Gauge(
+    "synapse_util_caches_response_cache_hits",
+    "",
+    ["name"],
+    registry=CACHE_METRIC_REGISTRY,
+)
 response_cache_evicted = Gauge(
-    "synapse_util_caches_response_cache:evicted_size", "", ["name", "reason"]
+    "synapse_util_caches_response_cache_evicted_size",
+    "",
+    ["name", "reason"],
+    registry=CACHE_METRIC_REGISTRY,
 )
-response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"])
+response_cache_total = Gauge(
+    "synapse_util_caches_response_cache", "", ["name"], registry=CACHE_METRIC_REGISTRY
+)
+
+
+# Register our custom cache metrics registry with the global registry
+REGISTRY.register(CACHE_METRIC_REGISTRY)
 
 
 class EvictionReason(Enum):
@@ -170,7 +207,7 @@ def register_cache(
     metric = CacheMetric(cache, cache_type, cache_name, collect_callback)
     metric_name = "cache_%s_%s" % (cache_type, cache_name)
     caches_by_name[cache_name] = cache
-    collectors_by_name[metric_name] = metric
+    CACHE_METRIC_REGISTRY.register_hook(metric_name, metric.collect)
     return metric
 
 
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 1d6ec22191..6425f851ea 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -14,15 +14,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import abc
 import enum
 import threading
 from typing import (
     Callable,
+    Collection,
+    Dict,
     Generic,
-    Iterable,
     MutableMapping,
     Optional,
+    Set,
     Sized,
+    Tuple,
     TypeVar,
     Union,
     cast,
@@ -31,7 +35,6 @@ from typing import (
 from prometheus_client import Gauge
 
 from twisted.internet import defer
-from twisted.python import failure
 from twisted.python.failure import Failure
 
 from synapse.util.async_helpers import ObservableDeferred
@@ -94,7 +97,7 @@ class DeferredCache(Generic[KT, VT]):
 
         # _pending_deferred_cache maps from the key value to a `CacheEntry` object.
         self._pending_deferred_cache: Union[
-            TreeCache, "MutableMapping[KT, CacheEntry]"
+            TreeCache, "MutableMapping[KT, CacheEntry[KT, VT]]"
         ] = cache_type()
 
         def metrics_cb() -> None:
@@ -159,15 +162,16 @@ class DeferredCache(Generic[KT, VT]):
         Raises:
             KeyError if the key is not found in the cache
         """
-        callbacks = [callback] if callback else []
         val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
         if val is not _Sentinel.sentinel:
-            val.callbacks.update(callbacks)
+            val.add_invalidation_callback(key, callback)
             if update_metrics:
                 m = self.cache.metrics
                 assert m  # we always have a name, so should always have metrics
                 m.inc_hits()
-            return val.deferred.observe()
+            return val.deferred(key)
+
+        callbacks = (callback,) if callback else ()
 
         val2 = self.cache.get(
             key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics
@@ -177,6 +181,73 @@ class DeferredCache(Generic[KT, VT]):
         else:
             return defer.succeed(val2)
 
+    def get_bulk(
+        self,
+        keys: Collection[KT],
+        callback: Optional[Callable[[], None]] = None,
+    ) -> Tuple[Dict[KT, VT], Optional["defer.Deferred[Dict[KT, VT]]"], Collection[KT]]:
+        """Bulk lookup of items in the cache.
+
+        Returns:
+            A 3-tuple of:
+                1. a dict of key/value of items already cached;
+                2. a deferred that resolves to a dict of key/value of items
+                   we're already fetching; and
+                3. a collection of keys that don't appear in the previous two.
+        """
+
+        # The cached results
+        cached = {}
+
+        # List of pending deferreds
+        pending = []
+
+        # Dict that gets filled out when the pending deferreds complete
+        pending_results = {}
+
+        # List of keys that aren't in either cache
+        missing = []
+
+        callbacks = (callback,) if callback else ()
+
+        for key in keys:
+            # Check if its in the main cache.
+            immediate_value = self.cache.get(
+                key,
+                _Sentinel.sentinel,
+                callbacks=callbacks,
+            )
+            if immediate_value is not _Sentinel.sentinel:
+                cached[key] = immediate_value
+                continue
+
+            # Check if its in the pending cache
+            pending_value = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
+            if pending_value is not _Sentinel.sentinel:
+                pending_value.add_invalidation_callback(key, callback)
+
+                def completed_cb(value: VT, key: KT) -> VT:
+                    pending_results[key] = value
+                    return value
+
+                # Add a callback to fill out `pending_results` when that completes
+                d = pending_value.deferred(key).addCallback(completed_cb, key)
+                pending.append(d)
+                continue
+
+            # Not in either cache
+            missing.append(key)
+
+        # If we've got pending deferreds, squash them into a single one that
+        # returns `pending_results`.
+        pending_deferred = None
+        if pending:
+            pending_deferred = defer.gatherResults(
+                pending, consumeErrors=True
+            ).addCallback(lambda _: pending_results)
+
+        return (cached, pending_deferred, missing)
+
     def get_immediate(
         self, key: KT, default: T, update_metrics: bool = True
     ) -> Union[VT, T]:
@@ -218,84 +289,89 @@ class DeferredCache(Generic[KT, VT]):
             value: a deferred which will complete with a result to add to the cache
             callback: An optional callback to be called when the entry is invalidated
         """
-        if not isinstance(value, defer.Deferred):
-            raise TypeError("not a Deferred")
-
-        callbacks = [callback] if callback else []
         self.check_thread()
 
-        existing_entry = self._pending_deferred_cache.pop(key, None)
-        if existing_entry:
-            existing_entry.invalidate()
+        self._pending_deferred_cache.pop(key, None)
 
         # XXX: why don't we invalidate the entry in `self.cache` yet?
 
-        # we can save a whole load of effort if the deferred is ready.
-        if value.called:
-            result = value.result
-            if not isinstance(result, failure.Failure):
-                self.cache.set(key, cast(VT, result), callbacks)
-            return value
-
         # otherwise, we'll add an entry to the _pending_deferred_cache for now,
         # and add callbacks to add it to the cache properly later.
+        entry = CacheEntrySingle[KT, VT](value)
+        entry.add_invalidation_callback(key, callback)
+        self._pending_deferred_cache[key] = entry
+        deferred = entry.deferred(key).addCallbacks(
+            self._completed_callback,
+            self._error_callback,
+            callbackArgs=(entry, key),
+            errbackArgs=(entry, key),
+        )
 
-        observable = ObservableDeferred(value, consumeErrors=True)
-        observer = observable.observe()
-        entry = CacheEntry(deferred=observable, callbacks=callbacks)
+        # we return a new Deferred which will be called before any subsequent observers.
+        return deferred
 
-        self._pending_deferred_cache[key] = entry
+    def start_bulk_input(
+        self,
+        keys: Collection[KT],
+        callback: Optional[Callable[[], None]] = None,
+    ) -> "CacheMultipleEntries[KT, VT]":
+        """Bulk set API for use when fetching multiple keys at once from the DB.
 
-        def compare_and_pop() -> bool:
-            """Check if our entry is still the one in _pending_deferred_cache, and
-            if so, pop it.
-
-            Returns true if the entries matched.
-            """
-            existing_entry = self._pending_deferred_cache.pop(key, None)
-            if existing_entry is entry:
-                return True
-
-            # oops, the _pending_deferred_cache has been updated since
-            # we started our query, so we are out of date.
-            #
-            # Better put back whatever we took out. (We do it this way
-            # round, rather than peeking into the _pending_deferred_cache
-            # and then removing on a match, to make the common case faster)
-            if existing_entry is not None:
-                self._pending_deferred_cache[key] = existing_entry
-
-            return False
-
-        def cb(result: VT) -> None:
-            if compare_and_pop():
-                self.cache.set(key, result, entry.callbacks)
-            else:
-                # we're not going to put this entry into the cache, so need
-                # to make sure that the invalidation callbacks are called.
-                # That was probably done when _pending_deferred_cache was
-                # updated, but it's possible that `set` was called without
-                # `invalidate` being previously called, in which case it may
-                # not have been. Either way, let's double-check now.
-                entry.invalidate()
-
-        def eb(_fail: Failure) -> None:
-            compare_and_pop()
-            entry.invalidate()
-
-        # once the deferred completes, we can move the entry from the
-        # _pending_deferred_cache to the real cache.
-        #
-        observer.addCallbacks(cb, eb)
+        Called *before* starting the fetch from the DB, and the caller *must*
+        call either `complete_bulk(..)` or `error_bulk(..)` on the return value.
+        """
 
-        # we return a new Deferred which will be called before any subsequent observers.
-        return observable.observe()
+        entry = CacheMultipleEntries[KT, VT]()
+        entry.add_global_invalidation_callback(callback)
+
+        for key in keys:
+            self._pending_deferred_cache[key] = entry
+
+        return entry
+
+    def _completed_callback(
+        self, value: VT, entry: "CacheEntry[KT, VT]", key: KT
+    ) -> VT:
+        """Called when a deferred is completed."""
+        # We check if the current entry matches the entry associated with the
+        # deferred. If they don't match then it got invalidated.
+        current_entry = self._pending_deferred_cache.pop(key, None)
+        if current_entry is not entry:
+            if current_entry:
+                self._pending_deferred_cache[key] = current_entry
+            return value
+
+        self.cache.set(key, value, entry.get_invalidation_callbacks(key))
+
+        return value
+
+    def _error_callback(
+        self,
+        failure: Failure,
+        entry: "CacheEntry[KT, VT]",
+        key: KT,
+    ) -> Failure:
+        """Called when a deferred errors."""
+
+        # We check if the current entry matches the entry associated with the
+        # deferred. If they don't match then it got invalidated.
+        current_entry = self._pending_deferred_cache.pop(key, None)
+        if current_entry is not entry:
+            if current_entry:
+                self._pending_deferred_cache[key] = current_entry
+            return failure
+
+        for cb in entry.get_invalidation_callbacks(key):
+            cb()
+
+        return failure
 
     def prefill(
         self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None
     ) -> None:
-        callbacks = [callback] if callback else []
+        callbacks = (callback,) if callback else ()
         self.cache.set(key, value, callbacks=callbacks)
+        self._pending_deferred_cache.pop(key, None)
 
     def invalidate(self, key: KT) -> None:
         """Delete a key, or tree of entries
@@ -311,41 +387,129 @@ class DeferredCache(Generic[KT, VT]):
         self.cache.del_multi(key)
 
         # if we have a pending lookup for this key, remove it from the
-        # _pending_deferred_cache, which will (a) stop it being returned
-        # for future queries and (b) stop it being persisted as a proper entry
+        # _pending_deferred_cache, which will (a) stop it being returned for
+        # future queries and (b) stop it being persisted as a proper entry
         # in self.cache.
         entry = self._pending_deferred_cache.pop(key, None)
-
-        # run the invalidation callbacks now, rather than waiting for the
-        # deferred to resolve.
         if entry:
             # _pending_deferred_cache.pop should either return a CacheEntry, or, in the
             # case of a TreeCache, a dict of keys to cache entries. Either way calling
             # iterate_tree_cache_entry on it will do the right thing.
             for entry in iterate_tree_cache_entry(entry):
-                entry.invalidate()
+                for cb in entry.get_invalidation_callbacks(key):
+                    cb()
 
     def invalidate_all(self) -> None:
         self.check_thread()
         self.cache.clear()
-        for entry in self._pending_deferred_cache.values():
-            entry.invalidate()
+        for key, entry in self._pending_deferred_cache.items():
+            for cb in entry.get_invalidation_callbacks(key):
+                cb()
+
         self._pending_deferred_cache.clear()
 
 
-class CacheEntry:
-    __slots__ = ["deferred", "callbacks", "invalidated"]
+class CacheEntry(Generic[KT, VT], metaclass=abc.ABCMeta):
+    """Abstract class for entries in `DeferredCache[KT, VT]`"""
 
-    def __init__(
-        self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]]
-    ):
-        self.deferred = deferred
-        self.callbacks = set(callbacks)
-        self.invalidated = False
-
-    def invalidate(self) -> None:
-        if not self.invalidated:
-            self.invalidated = True
-            for callback in self.callbacks:
-                callback()
-            self.callbacks.clear()
+    @abc.abstractmethod
+    def deferred(self, key: KT) -> "defer.Deferred[VT]":
+        """Get a deferred that a caller can wait on to get the value at the
+        given key"""
+        ...
+
+    @abc.abstractmethod
+    def add_invalidation_callback(
+        self, key: KT, callback: Optional[Callable[[], None]]
+    ) -> None:
+        """Add an invalidation callback"""
+        ...
+
+    @abc.abstractmethod
+    def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
+        """Get all invalidation callbacks"""
+        ...
+
+
+class CacheEntrySingle(CacheEntry[KT, VT]):
+    """An implementation of `CacheEntry` wrapping a deferred that results in a
+    single cache entry.
+    """
+
+    __slots__ = ["_deferred", "_callbacks"]
+
+    def __init__(self, deferred: "defer.Deferred[VT]") -> None:
+        self._deferred = ObservableDeferred(deferred, consumeErrors=True)
+        self._callbacks: Set[Callable[[], None]] = set()
+
+    def deferred(self, key: KT) -> "defer.Deferred[VT]":
+        return self._deferred.observe()
+
+    def add_invalidation_callback(
+        self, key: KT, callback: Optional[Callable[[], None]]
+    ) -> None:
+        if callback is None:
+            return
+
+        self._callbacks.add(callback)
+
+    def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
+        return self._callbacks
+
+
+class CacheMultipleEntries(CacheEntry[KT, VT]):
+    """Cache entry that is used for bulk lookups and insertions."""
+
+    __slots__ = ["_deferred", "_callbacks", "_global_callbacks"]
+
+    def __init__(self) -> None:
+        self._deferred: Optional[ObservableDeferred[Dict[KT, VT]]] = None
+        self._callbacks: Dict[KT, Set[Callable[[], None]]] = {}
+        self._global_callbacks: Set[Callable[[], None]] = set()
+
+    def deferred(self, key: KT) -> "defer.Deferred[VT]":
+        if not self._deferred:
+            self._deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
+        return self._deferred.observe().addCallback(lambda res: res.get(key))
+
+    def add_invalidation_callback(
+        self, key: KT, callback: Optional[Callable[[], None]]
+    ) -> None:
+        if callback is None:
+            return
+
+        self._callbacks.setdefault(key, set()).add(callback)
+
+    def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
+        return self._callbacks.get(key, set()) | self._global_callbacks
+
+    def add_global_invalidation_callback(
+        self, callback: Optional[Callable[[], None]]
+    ) -> None:
+        """Add a callback for when any keys get invalidated."""
+        if callback is None:
+            return
+
+        self._global_callbacks.add(callback)
+
+    def complete_bulk(
+        self,
+        cache: DeferredCache[KT, VT],
+        result: Dict[KT, VT],
+    ) -> None:
+        """Called when there is a result"""
+        for key, value in result.items():
+            cache._completed_callback(value, self, key)
+
+        if self._deferred:
+            self._deferred.callback(result)
+
+    def error_bulk(
+        self, cache: DeferredCache[KT, VT], keys: Collection[KT], failure: Failure
+    ) -> None:
+        """Called when bulk lookup failed."""
+        for key in keys:
+            cache._error_callback(failure, self, key)
+
+        if self._deferred:
+            self._deferred.errback(failure)
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 867f315b2a..10aff4d04a 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -25,6 +25,7 @@ from typing import (
     Generic,
     Hashable,
     Iterable,
+    List,
     Mapping,
     Optional,
     Sequence,
@@ -73,8 +74,10 @@ class _CacheDescriptorBase:
         num_args: Optional[int],
         uncached_args: Optional[Collection[str]] = None,
         cache_context: bool = False,
+        name: Optional[str] = None,
     ):
         self.orig = orig
+        self.name = name or orig.__name__
 
         arg_spec = inspect.getfullargspec(orig)
         all_args = arg_spec.args
@@ -211,7 +214,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
 
     def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
         cache: LruCache[CacheKey, Any] = LruCache(
-            cache_name=self.orig.__name__,
+            cache_name=self.name,
             max_size=self.max_entries,
         )
 
@@ -241,7 +244,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
 
         wrapped = cast(_CachedFunction, _wrapped)
         wrapped.cache = cache
-        obj.__dict__[self.orig.__name__] = wrapped
+        obj.__dict__[self.name] = wrapped
 
         return wrapped
 
@@ -301,12 +304,14 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         cache_context: bool = False,
         iterable: bool = False,
         prune_unread_entries: bool = True,
+        name: Optional[str] = None,
     ):
         super().__init__(
             orig,
             num_args=num_args,
             uncached_args=uncached_args,
             cache_context=cache_context,
+            name=name,
         )
 
         if tree and self.num_args < 2:
@@ -321,7 +326,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
 
     def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
         cache: DeferredCache[CacheKey, Any] = DeferredCache(
-            name=self.orig.__name__,
+            name=self.name,
             max_entries=self.max_entries,
             tree=self.tree,
             iterable=self.iterable,
@@ -372,7 +377,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         wrapped.cache = cache
         wrapped.num_args = self.num_args
 
-        obj.__dict__[self.orig.__name__] = wrapped
+        obj.__dict__[self.name] = wrapped
 
         return wrapped
 
@@ -393,6 +398,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
         cached_method_name: str,
         list_name: str,
         num_args: Optional[int] = None,
+        name: Optional[str] = None,
     ):
         """
         Args:
@@ -403,7 +409,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                 but including list_name) to use as cache keys. Defaults to all
                 named args of the function.
         """
-        super().__init__(orig, num_args=num_args, uncached_args=None)
+        super().__init__(orig, num_args=num_args, uncached_args=None, name=name)
 
         self.list_name = list_name
 
@@ -435,16 +441,6 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
             keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
             list_args = arg_dict[self.list_name]
 
-            results = {}
-
-            def update_results_dict(res: Any, arg: Hashable) -> None:
-                results[arg] = res
-
-            # list of deferreds to wait for
-            cached_defers = []
-
-            missing = set()
-
             # If the cache takes a single arg then that is used as the key,
             # otherwise a tuple is used.
             if num_args == 1:
@@ -452,6 +448,9 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                 def arg_to_cache_key(arg: Hashable) -> Hashable:
                     return arg
 
+                def cache_key_to_arg(key: tuple) -> Hashable:
+                    return key
+
             else:
                 keylist = list(keyargs)
 
@@ -459,58 +458,53 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                     keylist[self.list_pos] = arg
                     return tuple(keylist)
 
-            for arg in list_args:
-                try:
-                    res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
-                    if not res.called:
-                        res.addCallback(update_results_dict, arg)
-                        cached_defers.append(res)
-                    else:
-                        results[arg] = res.result
-                except KeyError:
-                    missing.add(arg)
+                def cache_key_to_arg(key: tuple) -> Hashable:
+                    return key[self.list_pos]
+
+            cache_keys = [arg_to_cache_key(arg) for arg in list_args]
+            immediate_results, pending_deferred, missing = cache.get_bulk(
+                cache_keys, callback=invalidate_callback
+            )
+
+            results = {cache_key_to_arg(key): v for key, v in immediate_results.items()}
+
+            cached_defers: List["defer.Deferred[Any]"] = []
+            if pending_deferred:
+
+                def update_results(r: Dict) -> None:
+                    for k, v in r.items():
+                        results[cache_key_to_arg(k)] = v
+
+                pending_deferred.addCallback(update_results)
+                cached_defers.append(pending_deferred)
 
             if missing:
-                # we need a deferred for each entry in the list,
-                # which we put in the cache. Each deferred resolves with the
-                # relevant result for that key.
-                deferreds_map = {}
-                for arg in missing:
-                    deferred: "defer.Deferred[Any]" = defer.Deferred()
-                    deferreds_map[arg] = deferred
-                    key = arg_to_cache_key(arg)
-                    cached_defers.append(
-                        cache.set(key, deferred, callback=invalidate_callback)
-                    )
+                cache_entry = cache.start_bulk_input(missing, invalidate_callback)
 
                 def complete_all(res: Dict[Hashable, Any]) -> None:
-                    # the wrapped function has completed. It returns a dict.
-                    # We can now update our own result map, and then resolve the
-                    # observable deferreds in the cache.
-                    for e, d1 in deferreds_map.items():
-                        val = res.get(e, None)
-                        # make sure we update the results map before running the
-                        # deferreds, because as soon as we run the last deferred, the
-                        # gatherResults() below will complete and return the result
-                        # dict to our caller.
-                        results[e] = val
-                        d1.callback(val)
+                    missing_results = {}
+                    for key in missing:
+                        arg = cache_key_to_arg(key)
+                        val = res.get(arg, None)
+
+                        results[arg] = val
+                        missing_results[key] = val
+
+                    cache_entry.complete_bulk(cache, missing_results)
 
                 def errback_all(f: Failure) -> None:
-                    # the wrapped function has failed. Propagate the failure into
-                    # the cache, which will invalidate the entry, and cause the
-                    # relevant cached_deferreds to fail, which will propagate the
-                    # failure to our caller.
-                    for d1 in deferreds_map.values():
-                        d1.errback(f)
+                    cache_entry.error_bulk(cache, missing, f)
 
                 args_to_call = dict(arg_dict)
-                args_to_call[self.list_name] = missing
+                args_to_call[self.list_name] = {
+                    cache_key_to_arg(key) for key in missing
+                }
 
                 # dispatch the call, and attach the two handlers
-                defer.maybeDeferred(
+                missing_d = defer.maybeDeferred(
                     preserve_fn(self.orig), **args_to_call
                 ).addCallbacks(complete_all, errback_all)
+                cached_defers.append(missing_d)
 
             if cached_defers:
                 d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
@@ -525,7 +519,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
             else:
                 return defer.succeed(results)
 
-        obj.__dict__[self.orig.__name__] = wrapped
+        obj.__dict__[self.name] = wrapped
 
         return wrapped
 
@@ -577,6 +571,7 @@ def cached(
     cache_context: bool = False,
     iterable: bool = False,
     prune_unread_entries: bool = True,
+    name: Optional[str] = None,
 ) -> Callable[[F], _CachedFunction[F]]:
     func = lambda orig: DeferredCacheDescriptor(
         orig,
@@ -587,13 +582,18 @@ def cached(
         cache_context=cache_context,
         iterable=iterable,
         prune_unread_entries=prune_unread_entries,
+        name=name,
     )
 
     return cast(Callable[[F], _CachedFunction[F]], func)
 
 
 def cachedList(
-    *, cached_method_name: str, list_name: str, num_args: Optional[int] = None
+    *,
+    cached_method_name: str,
+    list_name: str,
+    num_args: Optional[int] = None,
+    name: Optional[str] = None,
 ) -> Callable[[F], _CachedFunction[F]]:
     """Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.
 
@@ -628,6 +628,7 @@ def cachedList(
         cached_method_name=cached_method_name,
         list_name=list_name,
         num_args=num_args,
+        name=name,
     )
 
     return cast(Callable[[F], _CachedFunction[F]], func)
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index d267703df0..fa91479c97 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -14,11 +14,13 @@
 import enum
 import logging
 import threading
-from typing import Any, Dict, Generic, Iterable, Optional, Set, TypeVar
+from typing import Any, Dict, Generic, Iterable, Optional, Set, Tuple, TypeVar, Union
 
 import attr
+from typing_extensions import Literal
 
 from synapse.util.caches.lrucache import LruCache
+from synapse.util.caches.treecache import TreeCache
 
 logger = logging.getLogger(__name__)
 
@@ -33,10 +35,12 @@ DV = TypeVar("DV")
 
 # This class can't be generic because it uses slots with attrs.
 # See: https://github.com/python-attrs/attrs/issues/313
-@attr.s(slots=True, auto_attribs=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class DictionaryEntry:  # should be: Generic[DKT, DV].
     """Returned when getting an entry from the cache
 
+    If `full` is true then `known_absent` will be the empty set.
+
     Attributes:
         full: Whether the cache has the full or dict or just some keys.
             If not full then not all requested keys will necessarily be present
@@ -53,20 +57,90 @@ class DictionaryEntry:  # should be: Generic[DKT, DV].
         return len(self.value)
 
 
+class _FullCacheKey(enum.Enum):
+    """The key we use to cache the full dict."""
+
+    KEY = object()
+
+
 class _Sentinel(enum.Enum):
     # defining a sentinel in this way allows mypy to correctly handle the
     # type of a dictionary lookup.
     sentinel = object()
 
 
+class _PerKeyValue(Generic[DV]):
+    """The cached value of a dictionary key. If `value` is the sentinel,
+    indicates that the requested key is known to *not* be in the full dict.
+    """
+
+    __slots__ = ["value"]
+
+    def __init__(self, value: Union[DV, Literal[_Sentinel.sentinel]]) -> None:
+        self.value = value
+
+    def __len__(self) -> int:
+        # We add a `__len__` implementation as we use this class in a cache
+        # where the values are variable length.
+        return 1
+
+
 class DictionaryCache(Generic[KT, DKT, DV]):
     """Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
     fetching a subset of dictionary keys for a particular key.
+
+    This cache has two levels of key. First there is the "cache key" (of type
+    `KT`), which maps to a dict. The keys to that dict are the "dict key" (of
+    type `DKT`). The overall structure is therefore `KT->DKT->DV`. For
+    example, it might look like:
+
+       {
+           1: { 1: "a", 2: "b" },
+           2: { 1: "c" },
+       }
+
+    It is possible to look up either individual dict keys, or the *complete*
+    dict for a given cache key.
+
+    Each dict item, and the complete dict is treated as a separate LRU
+    entry for the purpose of cache expiry. For example, given:
+        dict_cache.get(1, None)  -> DictionaryEntry({1: "a", 2: "b"})
+        dict_cache.get(1, [1])  -> DictionaryEntry({1: "a"})
+        dict_cache.get(1, [2])  -> DictionaryEntry({2: "b"})
+
+    ... then the cache entry for the complete dict will expire first,
+    followed by the cache entry for the '1' dict key, and finally that
+    for the '2' dict key.
     """
 
     def __init__(self, name: str, max_entries: int = 1000):
-        self.cache: LruCache[KT, DictionaryEntry] = LruCache(
-            max_size=max_entries, cache_name=name, size_callback=len
+        # We use a single LruCache to store two different types of entries:
+        #   1. Map from (key, dict_key) -> dict value (or sentinel, indicating
+        #      the key doesn't exist in the dict); and
+        #   2. Map from (key, _FullCacheKey.KEY) -> full dict.
+        #
+        # The former is used when explicit keys of the dictionary are looked up,
+        # and the latter when the full dictionary is requested.
+        #
+        # If when explicit keys are requested and not in the cache, we then look
+        # to see if we have the full dict and use that if we do. If found in the
+        # full dict each key is added into the cache.
+        #
+        # This set up allows the `LruCache` to prune the full dict entries if
+        # they haven't been used in a while, even when there have been recent
+        # queries for subsets of the dict.
+        #
+        # Typing:
+        #     * A key of `(KT, DKT)` has a value of `_PerKeyValue`
+        #     * A key of `(KT, _FullCacheKey.KEY)` has a value of `Dict[DKT, DV]`
+        self.cache: LruCache[
+            Tuple[KT, Union[DKT, Literal[_FullCacheKey.KEY]]],
+            Union[_PerKeyValue, Dict[DKT, DV]],
+        ] = LruCache(
+            max_size=max_entries,
+            cache_name=name,
+            cache_type=TreeCache,
+            size_callback=len,
         )
 
         self.name = name
@@ -91,23 +165,83 @@ class DictionaryCache(Generic[KT, DKT, DV]):
         Args:
             key
             dict_keys: If given a set of keys then return only those keys
-                that exist in the cache.
+                that exist in the cache. If None then returns the full dict
+                if it is in the cache.
 
         Returns:
-            DictionaryEntry
+            DictionaryEntry: If `dict_keys` is not None then `DictionaryEntry`
+            will contain include the keys that are in the cache. If None then
+            will either return the full dict if in the cache, or the empty
+            dict (with `full` set to False) if it isn't.
         """
-        entry = self.cache.get(key, _Sentinel.sentinel)
-        if entry is not _Sentinel.sentinel:
-            if dict_keys is None:
-                return DictionaryEntry(
-                    entry.full, entry.known_absent, dict(entry.value)
-                )
+        if dict_keys is None:
+            # The caller wants the full set of dictionary keys for this cache key
+            return self._get_full_dict(key)
+
+        # We are being asked for a subset of keys.
+
+        # First go and check for each requested dict key in the cache, tracking
+        # which we couldn't find.
+        values = {}
+        known_absent = set()
+        missing = []
+        for dict_key in dict_keys:
+            entry = self.cache.get((key, dict_key), _Sentinel.sentinel)
+            if entry is _Sentinel.sentinel:
+                missing.append(dict_key)
+                continue
+
+            assert isinstance(entry, _PerKeyValue)
+
+            if entry.value is _Sentinel.sentinel:
+                known_absent.add(dict_key)
             else:
-                return DictionaryEntry(
-                    entry.full,
-                    entry.known_absent,
-                    {k: entry.value[k] for k in dict_keys if k in entry.value},
-                )
+                values[dict_key] = entry.value
+
+        # If we found everything we can return immediately.
+        if not missing:
+            return DictionaryEntry(False, known_absent, values)
+
+        # We are missing some keys, so check if we happen to have the full dict in
+        # the cache.
+        #
+        # We don't update the last access time for this cache fetch, as we
+        # aren't explicitly interested in the full dict and so we don't want
+        # requests for explicit dict keys to keep the full dict in the cache.
+        entry = self.cache.get(
+            (key, _FullCacheKey.KEY),
+            _Sentinel.sentinel,
+            update_last_access=False,
+        )
+        if entry is _Sentinel.sentinel:
+            # Not in the cache, return the subset of keys we found.
+            return DictionaryEntry(False, known_absent, values)
+
+        # We have the full dict!
+        assert isinstance(entry, dict)
+
+        for dict_key in missing:
+            # We explicitly add each dict key to the cache, so that cache hit
+            # rates and LRU times for each key can be tracked separately.
+            value = entry.get(dict_key, _Sentinel.sentinel)  # type: ignore[arg-type]
+            self.cache[(key, dict_key)] = _PerKeyValue(value)
+
+            if value is not _Sentinel.sentinel:
+                values[dict_key] = value
+
+        return DictionaryEntry(True, set(), values)
+
+    def _get_full_dict(
+        self,
+        key: KT,
+    ) -> DictionaryEntry:
+        """Fetch the full dict for the given key."""
+
+        # First we check if we have cached the full dict.
+        entry = self.cache.get((key, _FullCacheKey.KEY), _Sentinel.sentinel)
+        if entry is not _Sentinel.sentinel:
+            assert isinstance(entry, dict)
+            return DictionaryEntry(True, set(), entry)
 
         return DictionaryEntry(False, set(), {})
 
@@ -117,7 +251,13 @@ class DictionaryCache(Generic[KT, DKT, DV]):
         # Increment the sequence number so that any SELECT statements that
         # raced with the INSERT don't update the cache (SYN-369)
         self.sequence += 1
-        self.cache.pop(key, None)
+
+        # We want to drop all information about the dict for the given key, so
+        # we use `del_multi` to delete it all in one go.
+        #
+        # We ignore the type error here: `del_multi` accepts a truncated key
+        # (when the key type is a tuple).
+        self.cache.del_multi((key,))  # type: ignore[arg-type]
 
     def invalidate_all(self) -> None:
         self.check_thread()
@@ -131,7 +271,16 @@ class DictionaryCache(Generic[KT, DKT, DV]):
         value: Dict[DKT, DV],
         fetched_keys: Optional[Iterable[DKT]] = None,
     ) -> None:
-        """Updates the entry in the cache
+        """Updates the entry in the cache.
+
+        Note: This does *not* invalidate any existing entries for the `key`.
+        In particular, if we add an entry for the cached "full dict" with
+        `fetched_keys=None`, existing entries for individual dict keys are
+        not invalidated. Likewise, adding entries for individual keys does
+        not invalidate any cached value for the full dict.
+
+        In other words: if the underlying data is *changed*, the cache must
+        be explicitly invalidated via `.invalidate()`.
 
         Args:
             sequence
@@ -149,20 +298,27 @@ class DictionaryCache(Generic[KT, DKT, DV]):
             # Only update the cache if the caches sequence number matches the
             # number that the cache had before the SELECT was started (SYN-369)
             if fetched_keys is None:
-                self._insert(key, value, set())
+                self.cache[(key, _FullCacheKey.KEY)] = value
             else:
-                self._update_or_insert(key, value, fetched_keys)
+                self._update_subset(key, value, fetched_keys)
 
-    def _update_or_insert(
-        self, key: KT, value: Dict[DKT, DV], known_absent: Iterable[DKT]
+    def _update_subset(
+        self, key: KT, value: Dict[DKT, DV], fetched_keys: Iterable[DKT]
     ) -> None:
-        # We pop and reinsert as we need to tell the cache the size may have
-        # changed
+        """Add the given dictionary values as explicit keys in the cache.
+
+        Args:
+            key: top-level cache key
+            value: The dictionary with all the values that we should cache
+            fetched_keys: The full set of dict keys that were looked up. Any keys
+                here not in `value` should be marked as "known absent".
+        """
+
+        for dict_key, dict_value in value.items():
+            self.cache[(key, dict_key)] = _PerKeyValue(dict_value)
 
-        entry: DictionaryEntry = self.cache.pop(key, DictionaryEntry(False, set(), {}))
-        entry.value.update(value)
-        entry.known_absent.update(known_absent)
-        self.cache[key] = entry
+        for dict_key in fetched_keys:
+            if dict_key in value:
+                continue
 
-    def _insert(self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT]) -> None:
-        self.cache[key] = DictionaryEntry(True, known_absent, value)
+            self.cache[(key, dict_key)] = _PerKeyValue(_Sentinel.sentinel)
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 8ed5325c5d..aa93109d13 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -25,8 +25,10 @@ from typing import (
     Collection,
     Dict,
     Generic,
+    Iterable,
     List,
     Optional,
+    Tuple,
     Type,
     TypeVar,
     Union,
@@ -44,7 +46,11 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
 from synapse.metrics.jemalloc import get_jemalloc_stats
 from synapse.util import Clock, caches
 from synapse.util.caches import CacheMetric, EvictionReason, register_cache
-from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
+from synapse.util.caches.treecache import (
+    TreeCache,
+    iterate_tree_cache_entry,
+    iterate_tree_cache_items,
+)
 from synapse.util.linked_list import ListNode
 
 if TYPE_CHECKING:
@@ -537,6 +543,7 @@ class LruCache(Generic[KT, VT]):
             default: Literal[None] = None,
             callbacks: Collection[Callable[[], None]] = ...,
             update_metrics: bool = ...,
+            update_last_access: bool = ...,
         ) -> Optional[VT]:
             ...
 
@@ -546,6 +553,7 @@ class LruCache(Generic[KT, VT]):
             default: T,
             callbacks: Collection[Callable[[], None]] = ...,
             update_metrics: bool = ...,
+            update_last_access: bool = ...,
         ) -> Union[T, VT]:
             ...
 
@@ -555,10 +563,27 @@ class LruCache(Generic[KT, VT]):
             default: Optional[T] = None,
             callbacks: Collection[Callable[[], None]] = (),
             update_metrics: bool = True,
+            update_last_access: bool = True,
         ) -> Union[None, T, VT]:
+            """Look up a key in the cache
+
+            Args:
+                key
+                default
+                callbacks: A collection of callbacks that will fire when the
+                    node is removed from the cache (either due to invalidation
+                    or expiry).
+                update_metrics: Whether to update the hit rate metrics
+                update_last_access: Whether to update the last access metrics
+                    on a node if successfully fetched. These metrics are used
+                    to determine when to remove the node from the cache. Set
+                    to False if this fetch should *not* prevent a node from
+                    being expired.
+            """
             node = cache.get(key, None)
             if node is not None:
-                move_node_to_front(node)
+                if update_last_access:
+                    move_node_to_front(node)
                 node.add_callbacks(callbacks)
                 if update_metrics and metrics:
                     metrics.inc_hits()
@@ -568,6 +593,65 @@ class LruCache(Generic[KT, VT]):
                     metrics.inc_misses()
                 return default
 
+        @overload
+        def cache_get_multi(
+            key: tuple,
+            default: Literal[None] = None,
+            update_metrics: bool = True,
+        ) -> Union[None, Iterable[Tuple[KT, VT]]]:
+            ...
+
+        @overload
+        def cache_get_multi(
+            key: tuple,
+            default: T,
+            update_metrics: bool = True,
+        ) -> Union[T, Iterable[Tuple[KT, VT]]]:
+            ...
+
+        @synchronized
+        def cache_get_multi(
+            key: tuple,
+            default: Optional[T] = None,
+            update_metrics: bool = True,
+        ) -> Union[None, T, Iterable[Tuple[KT, VT]]]:
+            """Returns a generator yielding all entries under the given key.
+
+            Can only be used if backed by a tree cache.
+
+            Example:
+
+                cache = LruCache(10, cache_type=TreeCache)
+                cache[(1, 1)] = "a"
+                cache[(1, 2)] = "b"
+                cache[(2, 1)] = "c"
+
+                items = cache.get_multi((1,))
+                assert list(items) == [((1, 1), "a"), ((1, 2), "b")]
+
+            Returns:
+                Either default if the key doesn't exist, or a generator of the
+                key/value pairs.
+            """
+
+            assert isinstance(cache, TreeCache)
+
+            node = cache.get(key, None)
+            if node is not None:
+                if update_metrics and metrics:
+                    metrics.inc_hits()
+
+                # We store entries in the `TreeCache` with values of type `_Node`,
+                # which we need to unwrap.
+                return (
+                    (full_key, lru_node.value)
+                    for full_key, lru_node in iterate_tree_cache_items(key, node)
+                )
+            else:
+                if update_metrics and metrics:
+                    metrics.inc_misses()
+                return default
+
         @synchronized
         def cache_set(
             key: KT, value: VT, callbacks: Collection[Callable[[], None]] = ()
@@ -674,6 +758,8 @@ class LruCache(Generic[KT, VT]):
         self.setdefault = cache_set_default
         self.pop = cache_pop
         self.del_multi = cache_del_multi
+        if cache_type is TreeCache:
+            self.get_multi = cache_get_multi
         # `invalidate` is exposed for consistency with DeferredCache, so that it can be
         # invalidated by the cache invalidation replication stream.
         self.invalidate = cache_del_multi
@@ -730,3 +816,58 @@ class LruCache(Generic[KT, VT]):
         # This happens e.g. in the sync code where we have an expiring cache of
         # lru caches.
         self.clear()
+
+
+class AsyncLruCache(Generic[KT, VT]):
+    """
+    An asynchronous wrapper around a subset of the LruCache API.
+
+    On its own this doesn't change the behaviour but allows subclasses that
+    utilize external cache systems that require await behaviour to be created.
+    """
+
+    def __init__(self, *args, **kwargs):  # type: ignore
+        self._lru_cache: LruCache[KT, VT] = LruCache(*args, **kwargs)
+
+    async def get(
+        self, key: KT, default: Optional[T] = None, update_metrics: bool = True
+    ) -> Optional[VT]:
+        return self._lru_cache.get(key, update_metrics=update_metrics)
+
+    async def get_external(
+        self,
+        key: KT,
+        default: Optional[T] = None,
+        update_metrics: bool = True,
+    ) -> Optional[VT]:
+        # This method should fetch from any configured external cache, in this case noop.
+        return None
+
+    def get_local(
+        self, key: KT, default: Optional[T] = None, update_metrics: bool = True
+    ) -> Optional[VT]:
+        return self._lru_cache.get(key, update_metrics=update_metrics)
+
+    async def set(self, key: KT, value: VT) -> None:
+        self._lru_cache.set(key, value)
+
+    def set_local(self, key: KT, value: VT) -> None:
+        self._lru_cache.set(key, value)
+
+    async def invalidate(self, key: KT) -> None:
+        # This method should invalidate any external cache and then invalidate the LruCache.
+        return self._lru_cache.invalidate(key)
+
+    def invalidate_local(self, key: KT) -> None:
+        """Remove an entry from the local cache
+
+        This variant of `invalidate` is useful if we know that the external
+        cache has already been invalidated.
+        """
+        return self._lru_cache.invalidate(key)
+
+    async def contains(self, key: KT) -> bool:
+        return self._lru_cache.contains(key)
+
+    async def clear(self) -> None:
+        self._lru_cache.clear()
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index e78305f787..fec31da2b6 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -64,6 +64,15 @@ class TreeCache:
         self.size += 1
 
     def get(self, key, default=None):
+        """When `key` is a full key, fetches the value for the given key (if
+        any).
+
+        If `key` is only a partial key (i.e. a truncated tuple) then returns a
+        `TreeCacheNode`, which can be passed to the `iterate_tree_cache_*`
+        functions to iterate over all entries in the cache with keys that start
+        with the given partial key.
+        """
+
         node = self.root
         for k in key[:-1]:
             node = node.get(k, None)
@@ -126,6 +135,9 @@ class TreeCache:
     def values(self):
         return iterate_tree_cache_entry(self.root)
 
+    def items(self):
+        return iterate_tree_cache_items((), self.root)
+
     def __len__(self) -> int:
         return self.size
 
@@ -139,3 +151,32 @@ def iterate_tree_cache_entry(d):
             yield from iterate_tree_cache_entry(value_d)
     else:
         yield d
+
+
+def iterate_tree_cache_items(key, value):
+    """Helper function to iterate over the leaves of a tree, i.e. a dict of that
+    can contain dicts.
+
+    The provided key is a tuple that will get prepended to the returned keys.
+
+    Example:
+
+        cache = TreeCache()
+        cache[(1, 1)] = "a"
+        cache[(1, 2)] = "b"
+        cache[(2, 1)] = "c"
+
+        tree_node = cache.get((1,))
+
+        items = iterate_tree_cache_items((1,), tree_node)
+        assert list(items) == [((1, 1), "a"), ((1, 2), "b")]
+
+    Returns:
+        A generator yielding key/value pairs.
+    """
+    if isinstance(value, TreeCacheNode):
+        for sub_key, sub_value in value.items():
+            yield from iterate_tree_cache_items((*key, sub_key), sub_value)
+    else:
+        # we've reached a leaf of the tree.
+        yield key, value
diff --git a/synapse/util/cancellation.py b/synapse/util/cancellation.py
new file mode 100644
index 0000000000..472d2e3aeb
--- /dev/null
+++ b/synapse/util/cancellation.py
@@ -0,0 +1,56 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Callable, TypeVar
+
+F = TypeVar("F", bound=Callable[..., Any])
+
+
+def cancellable(function: F) -> F:
+    """Marks a function as cancellable.
+
+    Servlet methods with this decorator will be cancelled if the client disconnects before we
+    finish processing the request.
+
+    Although this annotation is particularly useful for servlet methods, it's also
+    useful for intermediate functions, where it documents the fact that the function has
+    been audited for cancellation safety and needs to preserve that.
+    This then simplifies auditing new functions that call those same intermediate
+    functions.
+
+    During cancellation, `Deferred.cancel()` will be invoked on the `Deferred` wrapping
+    the method. The `cancel()` call will propagate down to the `Deferred` that is
+    currently being waited on. That `Deferred` will raise a `CancelledError`, which will
+    propagate up, as per normal exception handling.
+
+    Before applying this decorator to a new function, you MUST recursively check
+    that all `await`s in the function are on `async` functions or `Deferred`s that
+    handle cancellation cleanly, otherwise a variety of bugs may occur, ranging from
+    premature logging context closure, to stuck requests, to database corruption.
+
+    See the documentation page on Cancellation for more information.
+
+    Usage:
+        class SomeServlet(RestServlet):
+            @cancellable
+            async def on_GET(self, request: SynapseRequest) -> ...:
+                ...
+    """
+
+    function.cancellable = True  # type: ignore[attr-defined]
+    return function
+
+
+def is_function_cancellable(function: Callable[..., Any]) -> bool:
+    """Checks whether a servlet method has the `@cancellable` flag."""
+    return getattr(function, "cancellable", False)
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index bc3b4938ea..165480bdbe 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -15,9 +15,9 @@
 import logging
 from functools import wraps
 from types import TracebackType
-from typing import Awaitable, Callable, Optional, Type, TypeVar
+from typing import Awaitable, Callable, Dict, Generator, Optional, Type, TypeVar
 
-from prometheus_client import Counter
+from prometheus_client import CollectorRegistry, Counter, Metric
 from typing_extensions import Concatenate, ParamSpec, Protocol
 
 from synapse.logging.context import (
@@ -208,3 +208,33 @@ class Measure:
         metrics.real_time_sum += duration
 
         # TODO: Add other in flight metrics.
+
+
+class DynamicCollectorRegistry(CollectorRegistry):
+    """
+    Custom Prometheus Collector registry that calls a hook first, allowing you
+    to update metrics on-demand.
+
+    Don't forget to register this registry with the main registry!
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self._pre_update_hooks: Dict[str, Callable[[], None]] = {}
+
+    def collect(self) -> Generator[Metric, None, None]:
+        """
+        Collects metrics, calling pre-update hooks first.
+        """
+
+        for pre_update_hook in self._pre_update_hooks.values():
+            pre_update_hook()
+
+        yield from super().collect()
+
+    def register_hook(self, metric_name: str, hook: Callable[[], None]) -> None:
+        """
+        Registers a hook that is called before metric collection.
+        """
+
+        self._pre_update_hooks[metric_name] = hook
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index dfe628c97e..9f64fed0d7 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -15,18 +15,35 @@
 import collections
 import contextlib
 import logging
+import threading
 import typing
-from typing import Any, DefaultDict, Iterator, List, Set
+from typing import (
+    Any,
+    Callable,
+    DefaultDict,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Set,
+    Tuple,
+)
+
+from prometheus_client.core import Counter
+from typing_extensions import ContextManager
 
 from twisted.internet import defer
 
 from synapse.api.errors import LimitExceededError
-from synapse.config.ratelimiting import FederationRateLimitConfig
+from synapse.config.ratelimiting import FederationRatelimitSettings
 from synapse.logging.context import (
     PreserveLoggingContext,
     make_deferred_yieldable,
     run_in_background,
 )
+from synapse.logging.opentracing import start_active_span
+from synapse.metrics import Histogram, LaterGauge
 from synapse.util import Clock
 
 if typing.TYPE_CHECKING:
@@ -35,15 +52,127 @@ if typing.TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+# Track how much the ratelimiter is affecting requests
+rate_limit_sleep_counter = Counter(
+    "synapse_rate_limit_sleep",
+    "Number of requests slept by the rate limiter",
+    ["rate_limiter_name"],
+)
+rate_limit_reject_counter = Counter(
+    "synapse_rate_limit_reject",
+    "Number of requests rejected by the rate limiter",
+    ["rate_limiter_name"],
+)
+queue_wait_timer = Histogram(
+    "synapse_rate_limit_queue_wait_time_seconds",
+    "Amount of time spent waiting for the rate limiter to let our request through.",
+    ["rate_limiter_name"],
+    buckets=(
+        0.005,
+        0.01,
+        0.025,
+        0.05,
+        0.1,
+        0.25,
+        0.5,
+        0.75,
+        1.0,
+        2.5,
+        5.0,
+        10.0,
+        20.0,
+        "+Inf",
+    ),
+)
+
+
+_rate_limiter_instances: Set["FederationRateLimiter"] = set()
+# Protects the _rate_limiter_instances set from concurrent access
+_rate_limiter_instances_lock = threading.Lock()
+
+
+def _get_counts_from_rate_limiter_instance(
+    count_func: Callable[["FederationRateLimiter"], int]
+) -> Mapping[Tuple[str, ...], int]:
+    """Returns a count of something (slept/rejected hosts) by (metrics_name)"""
+    # Cast to a list to prevent it changing while the Prometheus
+    # thread is collecting metrics
+    with _rate_limiter_instances_lock:
+        rate_limiter_instances = list(_rate_limiter_instances)
+
+    # Map from (metrics_name,) -> int, the number of something like slept hosts
+    # or rejected hosts. The key type is Tuple[str], but we leave the length
+    # unspecified for compatability with LaterGauge's annotations.
+    counts: Dict[Tuple[str, ...], int] = {}
+    for rate_limiter_instance in rate_limiter_instances:
+        # Only track metrics if they provided a `metrics_name` to
+        # differentiate this instance of the rate limiter.
+        if rate_limiter_instance.metrics_name:
+            key = (rate_limiter_instance.metrics_name,)
+            counts[key] = count_func(rate_limiter_instance)
+
+    return counts
+
+
+# We track the number of affected hosts per time-period so we can
+# differentiate one really noisy homeserver from a general
+# ratelimit tuning problem across the federation.
+LaterGauge(
+    "synapse_rate_limit_sleep_affected_hosts",
+    "Number of hosts that had requests put to sleep",
+    ["rate_limiter_name"],
+    lambda: _get_counts_from_rate_limiter_instance(
+        lambda rate_limiter_instance: sum(
+            ratelimiter.should_sleep()
+            for ratelimiter in rate_limiter_instance.ratelimiters.values()
+        )
+    ),
+)
+LaterGauge(
+    "synapse_rate_limit_reject_affected_hosts",
+    "Number of hosts that had requests rejected",
+    ["rate_limiter_name"],
+    lambda: _get_counts_from_rate_limiter_instance(
+        lambda rate_limiter_instance: sum(
+            ratelimiter.should_reject()
+            for ratelimiter in rate_limiter_instance.ratelimiters.values()
+        )
+    ),
+)
+
+
 class FederationRateLimiter:
-    def __init__(self, clock: Clock, config: FederationRateLimitConfig):
+    """Used to rate limit request per-host."""
+
+    def __init__(
+        self,
+        clock: Clock,
+        config: FederationRatelimitSettings,
+        metrics_name: Optional[str] = None,
+    ):
+        """
+        Args:
+            clock
+            config
+            metrics_name: The name of the rate limiter so we can differentiate it
+                from the rest in the metrics. If `None`, we don't track metrics
+                for this rate limiter.
+
+        """
+        self.metrics_name = metrics_name
+
         def new_limiter() -> "_PerHostRatelimiter":
-            return _PerHostRatelimiter(clock=clock, config=config)
+            return _PerHostRatelimiter(
+                clock=clock, config=config, metrics_name=metrics_name
+            )
 
         self.ratelimiters: DefaultDict[
             str, "_PerHostRatelimiter"
         ] = collections.defaultdict(new_limiter)
 
+        with _rate_limiter_instances_lock:
+            _rate_limiter_instances.add(self)
+
     def ratelimit(self, host: str) -> "_GeneratorContextManager[defer.Deferred[None]]":
         """Used to ratelimit an incoming request from a given host
 
@@ -59,17 +188,27 @@ class FederationRateLimiter:
         Returns:
             context manager which returns a deferred.
         """
-        return self.ratelimiters[host].ratelimit()
+        return self.ratelimiters[host].ratelimit(host)
 
 
 class _PerHostRatelimiter:
-    def __init__(self, clock: Clock, config: FederationRateLimitConfig):
+    def __init__(
+        self,
+        clock: Clock,
+        config: FederationRatelimitSettings,
+        metrics_name: Optional[str] = None,
+    ):
         """
         Args:
             clock
             config
+            metrics_name: The name of the rate limiter so we can differentiate it
+                from the rest in the metrics. If `None`, we don't track metrics
+                for this rate limiter.
+                from the rest in the metrics
         """
         self.clock = clock
+        self.metrics_name = metrics_name
 
         self.window_size = config.window_size
         self.sleep_limit = config.sleep_limit
@@ -94,19 +233,45 @@ class _PerHostRatelimiter:
         self.request_times: List[int] = []
 
     @contextlib.contextmanager
-    def ratelimit(self) -> "Iterator[defer.Deferred[None]]":
+    def ratelimit(self, host: str) -> "Iterator[defer.Deferred[None]]":
         # `contextlib.contextmanager` takes a generator and turns it into a
         # context manager. The generator should only yield once with a value
         # to be returned by manager.
         # Exceptions will be reraised at the yield.
 
+        self.host = host
+
         request_id = object()
-        ret = self._on_enter(request_id)
+        # Ideally we'd use `Deferred.fromCoroutine()` here, to save on redundant
+        # type-checking, but we'd need Twisted >= 21.2.
+        ret = defer.ensureDeferred(self._on_enter_with_tracing(request_id))
         try:
             yield ret
         finally:
             self._on_exit(request_id)
 
+    def should_reject(self) -> bool:
+        """
+        Whether to reject the request if we already have too many queued up
+        (either sleeping or in the ready queue).
+        """
+        queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
+        return queue_size > self.reject_limit
+
+    def should_sleep(self) -> bool:
+        """
+        Whether to sleep the request if we already have too many requests coming
+        through within the window.
+        """
+        return len(self.request_times) > self.sleep_limit
+
+    async def _on_enter_with_tracing(self, request_id: object) -> None:
+        maybe_metrics_cm: ContextManager = contextlib.nullcontext()
+        if self.metrics_name:
+            maybe_metrics_cm = queue_wait_timer.labels(self.metrics_name).time()
+        with start_active_span("ratelimit wait"), maybe_metrics_cm:
+            await self._on_enter(request_id)
+
     def _on_enter(self, request_id: object) -> "defer.Deferred[None]":
         time_now = self.clock.time_msec()
 
@@ -117,8 +282,10 @@ class _PerHostRatelimiter:
 
         # reject the request if we already have too many queued up (either
         # sleeping or in the ready queue).
-        queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
-        if queue_size > self.reject_limit:
+        if self.should_reject():
+            logger.debug("Ratelimiter(%s): rejecting request", self.host)
+            if self.metrics_name:
+                rate_limit_reject_counter.labels(self.metrics_name).inc()
             raise LimitExceededError(
                 retry_after_ms=int(self.window_size / self.sleep_limit)
             )
@@ -130,7 +297,8 @@ class _PerHostRatelimiter:
                 queue_defer: defer.Deferred[None] = defer.Deferred()
                 self.ready_request_queue[request_id] = queue_defer
                 logger.info(
-                    "Ratelimiter: queueing request (queue now %i items)",
+                    "Ratelimiter(%s): queueing request (queue now %i items)",
+                    self.host,
                     len(self.ready_request_queue),
                 )
 
@@ -139,19 +307,29 @@ class _PerHostRatelimiter:
                 return defer.succeed(None)
 
         logger.debug(
-            "Ratelimit [%s]: len(self.request_times)=%d",
+            "Ratelimit(%s) [%s]: len(self.request_times)=%d",
+            self.host,
             id(request_id),
             len(self.request_times),
         )
 
-        if len(self.request_times) > self.sleep_limit:
-            logger.debug("Ratelimiter: sleeping request for %f sec", self.sleep_sec)
+        if self.should_sleep():
+            logger.debug(
+                "Ratelimiter(%s) [%s]: sleeping request for %f sec",
+                self.host,
+                id(request_id),
+                self.sleep_sec,
+            )
+            if self.metrics_name:
+                rate_limit_sleep_counter.labels(self.metrics_name).inc()
             ret_defer = run_in_background(self.clock.sleep, self.sleep_sec)
 
             self.sleeping_requests.add(request_id)
 
             def on_wait_finished(_: Any) -> "defer.Deferred[None]":
-                logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id))
+                logger.debug(
+                    "Ratelimit(%s) [%s]: Finished sleeping", self.host, id(request_id)
+                )
                 self.sleeping_requests.discard(request_id)
                 queue_defer = queue_request()
                 return queue_defer
@@ -161,7 +339,9 @@ class _PerHostRatelimiter:
             ret_defer = queue_request()
 
         def on_start(r: object) -> object:
-            logger.debug("Ratelimit [%s]: Processing req", id(request_id))
+            logger.debug(
+                "Ratelimit(%s) [%s]: Processing req", self.host, id(request_id)
+            )
             self.current_processing.add(request_id)
             return r
 
@@ -183,7 +363,7 @@ class _PerHostRatelimiter:
         return make_deferred_yieldable(ret_defer)
 
     def _on_exit(self, request_id: object) -> None:
-        logger.debug("Ratelimit [%s]: Processed req", id(request_id))
+        logger.debug("Ratelimit(%s) [%s]: Processed req", self.host, id(request_id))
         self.current_processing.discard(request_id)
         try:
             # start processing the next item on the queue.
diff --git a/synapse/util/rust.py b/synapse/util/rust.py
new file mode 100644
index 0000000000..30ecb9ffd9
--- /dev/null
+++ b/synapse/util/rust.py
@@ -0,0 +1,84 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+from hashlib import blake2b
+
+import synapse
+from synapse.synapse_rust import get_rust_file_digest
+
+
+def check_rust_lib_up_to_date() -> None:
+    """For editable installs check if the rust library is outdated and needs to
+    be rebuilt.
+    """
+
+    if not _dist_is_editable():
+        return
+
+    synapse_dir = os.path.dirname(synapse.__file__)
+    synapse_root = os.path.abspath(os.path.join(synapse_dir, ".."))
+
+    # Double check we've not gone into site-packages...
+    if os.path.basename(synapse_root) == "site-packages":
+        return
+
+    # ... and it looks like the root of a python project.
+    if not os.path.exists("pyproject.toml"):
+        return
+
+    # Get the hash of all Rust source files
+    hash = _hash_rust_files_in_directory(os.path.join(synapse_root, "rust", "src"))
+
+    if hash != get_rust_file_digest():
+        raise Exception("Rust module outdated. Please rebuild using `poetry install`")
+
+
+def _hash_rust_files_in_directory(directory: str) -> str:
+    """Get the hash of all files in a directory (recursively)"""
+
+    directory = os.path.abspath(directory)
+
+    paths = []
+
+    dirs = [directory]
+    while dirs:
+        dir = dirs.pop()
+        with os.scandir(dir) as d:
+            for entry in d:
+                if entry.is_dir():
+                    dirs.append(entry.path)
+                else:
+                    paths.append(entry.path)
+
+    # We sort to make sure that we get a consistent and well-defined ordering.
+    paths.sort()
+
+    hasher = blake2b()
+
+    for path in paths:
+        with open(os.path.join(directory, path), "rb") as f:
+            hasher.update(f.read())
+
+    return hasher.hexdigest()
+
+
+def _dist_is_editable() -> bool:
+    """Is distribution an editable install?"""
+    for path_item in sys.path:
+        egg_link = os.path.join(path_item, "matrix-synapse.egg-link")
+        if os.path.isfile(egg_link):
+            return True
+    return False
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 9abbaa5a64..c810a05907 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -23,6 +23,7 @@ from synapse.api.constants import EventTypes, HistoryVisibility, Membership
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.events.utils import prune_event
+from synapse.logging.opentracing import trace
 from synapse.storage.controllers import StorageControllers
 from synapse.storage.databases.main import DataStore
 from synapse.storage.state import StateFilter
@@ -51,6 +52,7 @@ MEMBERSHIP_PRIORITY = (
 _HISTORY_VIS_KEY: Final[Tuple[str, str]] = (EventTypes.RoomHistoryVisibility, "")
 
 
+@trace
 async def filter_events_for_client(
     storage: StorageControllers,
     user_id: str,
@@ -71,8 +73,8 @@ async def filter_events_for_client(
           * the user is not currently a member of the room, and:
           * the user has not been a member of the room since the given
             events
-        always_include_ids: set of event ids to specifically
-            include (unless sender is ignored)
+        always_include_ids: set of event ids to specifically include, if present
+            in events (unless sender is ignored)
         filter_send_to_client: Whether we're checking an event that's going to be
             sent to a client. This might not always be the case since this function can
             also be called to check whether a user can see the state at a given point.