Diffstat (limited to 'synapse')
 synapse/api/auth.py | 7
 synapse/api/room_versions.py | 32
 synapse/app/_base.py | 150
 synapse/app/generic_worker.py | 22
 synapse/app/homeserver.py | 76
 synapse/config/_util.py | 2
 synapse/config/sso.py | 27
 synapse/config/workers.py | 10
 synapse/events/utils.py | 16
 synapse/federation/federation_server.py | 21
 synapse/handlers/auth.py | 95
 synapse/handlers/cas_handler.py | 38
 synapse/handlers/deactivate_account.py | 18
 synapse/handlers/devicemessage.py | 48
 synapse/handlers/oidc_handler.py | 320
 synapse/handlers/profile.py | 12
 synapse/handlers/room.py | 2
 synapse/handlers/saml_handler.py | 28
 synapse/handlers/sso.py | 136
 synapse/handlers/ui_auth/__init__.py | 15
 synapse/http/__init__.py | 15
 synapse/http/matrixfederationclient.py | 10
 synapse/http/site.py | 18
 synapse/logging/context.py | 50
 synapse/notifier.py | 39
 synapse/push/bulk_push_rule_evaluator.py | 12
 synapse/replication/slave/storage/deviceinbox.py | 40
 synapse/replication/tcp/handler.py | 9
 synapse/res/templates/sso_login_idp_picker.html | 28
 synapse/rest/admin/users.py | 29
 synapse/rest/client/v1/login.py | 89
 synapse/rest/client/v2_alpha/account.py | 44
 synapse/rest/client/v2_alpha/auth.py | 53
 synapse/rest/client/v2_alpha/devices.py | 12
 synapse/rest/client/v2_alpha/keys.py | 6
 synapse/rest/client/v2_alpha/register.py | 21
 synapse/rest/synapse/client/pick_idp.py | 82
 synapse/server.py | 6
 synapse/static/client/login/style.css | 5
 synapse/storage/database.py | 30
 synapse/storage/databases/main/__init__.py | 33
 synapse/storage/databases/main/account_data.py | 157
 synapse/storage/databases/main/client_ips.py | 54
 synapse/storage/databases/main/deviceinbox.py | 252
 synapse/storage/databases/main/end_to_end_keys.py | 88
 synapse/storage/databases/main/event_federation.py | 185
 synapse/storage/databases/main/events.py | 584
 synapse/storage/databases/main/events_bg_updates.py | 124
 synapse/storage/databases/main/profile.py | 2
 synapse/storage/databases/main/room.py | 51
 synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.postgres | 16
 synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.sqlite | 62
 synapse/storage/databases/main/schema/delta/58/28rejected_events_metadata.sql | 17
 synapse/storage/databases/main/schema/delta/59/01ignored_user.py | 82
 synapse/storage/databases/main/schema/delta/59/02shard_send_to_device.sql | 18
 synapse/storage/databases/main/schema/delta/59/03shard_send_to_device_sequence.sql.postgres | 25
 synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql | 52
 synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres | 16
 synapse/storage/databases/main/schema/delta/59/04drop_account_data.sql | 17
 synapse/storage/databases/main/schema/delta/59/05cache_invalidation.sql | 17
 synapse/storage/databases/main/tags.py | 10
 synapse/storage/prepare_database.py | 22
 synapse/util/caches/deferred_cache.py | 2
 synapse/util/iterutils.py | 53
 synapse/util/metrics.py | 11
 65 files changed, 2665 insertions(+), 958 deletions(-)
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 48c4d7b0be..67ecbd32ff 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -33,6 +33,7 @@ from synapse.api.errors import (
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.appservice import ApplicationService
 from synapse.events import EventBase
+from synapse.http import get_request_user_agent
 from synapse.http.site import SynapseRequest
 from synapse.logging import opentracing as opentracing
 from synapse.storage.databases.main.registration import TokenLookupResult
@@ -186,8 +187,8 @@ class Auth:
             AuthError if access is denied for the user in the access token
         """
         try:
-            ip_addr = self.hs.get_ip_from_request(request)
-            user_agent = request.get_user_agent("")
+            ip_addr = request.getClientIP()
+            user_agent = get_request_user_agent(request)
 
             access_token = self.get_access_token_from_request(request)
 
@@ -275,7 +276,7 @@ class Auth:
             return None, None
 
         if app_service.ip_range_whitelist:
-            ip_address = IPAddress(self.hs.get_ip_from_request(request))
+            ip_address = IPAddress(request.getClientIP())
             if ip_address not in app_service.ip_range_whitelist:
                 return None, None
 
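The auth.py hunks swap the ad-hoc `request.get_user_agent("")` for a shared `get_request_user_agent` helper from `synapse.http` (its hunk is counted in the diffstat above but not shown in this portion of the diff). A minimal sketch of what such a helper looks like, assuming Twisted's bytes-keyed header API:

```python
from twisted.web.server import Request


def get_request_user_agent(request: Request, default: str = "") -> str:
    """Return the request's User-Agent header as a str, or `default`.

    Twisted returns the header as bytes when looked up with a bytes key
    (and None when absent), so decode defensively rather than crashing
    on non-ASCII user agents.
    """
    header = request.getHeader(b"User-Agent")
    return header.decode("utf-8", "replace") if header else default
```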
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index f3ecbf36b6..de2cc15d33 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -51,11 +51,11 @@ class RoomDisposition:
 class RoomVersion:
     """An object which describes the unique attributes of a room version."""
 
-    identifier = attr.ib()  # str; the identifier for this version
-    disposition = attr.ib()  # str; one of the RoomDispositions
-    event_format = attr.ib()  # int; one of the EventFormatVersions
-    state_res = attr.ib()  # int; one of the StateResolutionVersions
-    enforce_key_validity = attr.ib()  # bool
+    identifier = attr.ib(type=str)  # the identifier for this version
+    disposition = attr.ib(type=str)  # one of the RoomDispositions
+    event_format = attr.ib(type=int)  # one of the EventFormatVersions
+    state_res = attr.ib(type=int)  # one of the StateResolutionVersions
+    enforce_key_validity = attr.ib(type=bool)
 
     # bool: before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules
     special_case_aliases_auth = attr.ib(type=bool)
@@ -64,9 +64,11 @@ class RoomVersion:
     # * Floats
     # * NaN, Infinity, -Infinity
     strict_canonicaljson = attr.ib(type=bool)
-    # bool: MSC2209: Check 'notifications' key while verifying
+    # MSC2209: Check 'notifications' key while verifying
     # m.room.power_levels auth rules.
     limit_notifications_power_levels = attr.ib(type=bool)
+    # MSC2174/MSC2176: Apply updated redaction rules algorithm.
+    msc2176_redaction_rules = attr.ib(type=bool)
 
 
 class RoomVersions:
@@ -79,6 +81,7 @@ class RoomVersions:
         special_case_aliases_auth=True,
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
+        msc2176_redaction_rules=False,
     )
     V2 = RoomVersion(
         "2",
@@ -89,6 +92,7 @@ class RoomVersions:
         special_case_aliases_auth=True,
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
+        msc2176_redaction_rules=False,
     )
     V3 = RoomVersion(
         "3",
@@ -99,6 +103,7 @@ class RoomVersions:
         special_case_aliases_auth=True,
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
+        msc2176_redaction_rules=False,
     )
     V4 = RoomVersion(
         "4",
@@ -109,6 +114,7 @@ class RoomVersions:
         special_case_aliases_auth=True,
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
+        msc2176_redaction_rules=False,
     )
     V5 = RoomVersion(
         "5",
@@ -119,6 +125,7 @@ class RoomVersions:
         special_case_aliases_auth=True,
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
+        msc2176_redaction_rules=False,
     )
     V6 = RoomVersion(
         "6",
@@ -129,6 +136,18 @@ class RoomVersions:
         special_case_aliases_auth=False,
         strict_canonicaljson=True,
         limit_notifications_power_levels=True,
+        msc2176_redaction_rules=False,
+    )
+    MSC2176 = RoomVersion(
+        "org.matrix.msc2176",
+        RoomDisposition.UNSTABLE,
+        EventFormatVersions.V3,
+        StateResolutionVersions.V2,
+        enforce_key_validity=True,
+        special_case_aliases_auth=False,
+        strict_canonicaljson=True,
+        limit_notifications_power_levels=True,
+        msc2176_redaction_rules=True,
     )
 
 
@@ -141,5 +160,6 @@ KNOWN_ROOM_VERSIONS = {
         RoomVersions.V4,
         RoomVersions.V5,
         RoomVersions.V6,
+        RoomVersions.MSC2176,
     )
 }  # type: Dict[str, RoomVersion]
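Callers look the new unstable version up through `KNOWN_ROOM_VERSIONS` like any other; a short usage sketch:

```python
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS

# Room version identifiers arrive over the wire as opaque strings, so
# unknown versions must be rejected rather than guessed at.
room_version = KNOWN_ROOM_VERSIONS.get("org.matrix.msc2176")
if room_version is None:
    raise ValueError("unsupported room version")

# The new flag gates the updated redaction algorithm (see the
# events/utils.py hunks below).
assert room_version.msc2176_redaction_rules
```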
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 37ecdbe3d8..395e202b89 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2017 New Vector Ltd
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,7 +20,7 @@ import signal
 import socket
 import sys
 import traceback
-from typing import Iterable
+from typing import Awaitable, Callable, Iterable
 
 from typing_extensions import NoReturn
 
@@ -143,6 +144,45 @@ def quit_with_error(error_string: str) -> NoReturn:
     sys.exit(1)
 
 
+def register_start(cb: Callable[..., Awaitable], *args, **kwargs) -> None:
+    """Register a callback with the reactor, to be called once it is running
+
+    This can be used to initialise parts of the system which require an asynchronous
+    setup.
+
+    Any exception raised by the callback will be printed and logged, and the process
+    will exit.
+    """
+
+    async def wrapper():
+        try:
+            await cb(*args, **kwargs)
+        except Exception:
+            # previously, we used Failure().printTraceback() here, in the hope that
+            # would give better tracebacks than traceback.print_exc(). However, that
+            # doesn't handle chained exceptions (with a __cause__ or __context__) well,
+            # and I *think* the need for Failure() is reduced now that we mostly use
+            # async/await.
+
+            # Write the exception to both the logs *and* the unredirected stderr,
+            # because people tend to get confused if it only goes to one or the other.
+            #
+            # One problem with this is that if people are using a logging config that
+            # logs to the console (as is common eg under docker), they will get two
+            # copies of the exception. We could maybe try to detect that, but it's
+            # probably a cost we can bear.
+            logger.fatal("Error during startup", exc_info=True)
+            print("Error during startup:", file=sys.__stderr__)
+            traceback.print_exc(file=sys.__stderr__)
+
+            # it's no use calling sys.exit here, since that just raises a SystemExit
+            # exception which is then caught by the reactor, and everything carries
+            # on as normal.
+            os._exit(1)
+
+    reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper()))
+
+
 def listen_metrics(bind_addresses, port):
     """
     Start Prometheus metrics server.
@@ -227,7 +267,7 @@ def refresh_certificate(hs):
         logger.info("Context factories updated.")
 
 
-def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
+async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
     """
     Start a Synapse server or worker.
 
@@ -241,75 +281,67 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
         hs: homeserver instance
         listeners: Listener configuration ('listeners' in homeserver.yaml)
     """
-    try:
-        # Set up the SIGHUP machinery.
-        if hasattr(signal, "SIGHUP"):
+    # Set up the SIGHUP machinery.
+    if hasattr(signal, "SIGHUP"):
+        reactor = hs.get_reactor()
 
-            reactor = hs.get_reactor()
+        @wrap_as_background_process("sighup")
+        def handle_sighup(*args, **kwargs):
+            # Tell systemd our state, if we're using it. This will silently fail if
+            # we're not using systemd.
+            sdnotify(b"RELOADING=1")
 
-            @wrap_as_background_process("sighup")
-            def handle_sighup(*args, **kwargs):
-                # Tell systemd our state, if we're using it. This will silently fail if
-                # we're not using systemd.
-                sdnotify(b"RELOADING=1")
+            for i, args, kwargs in _sighup_callbacks:
+                i(*args, **kwargs)
 
-                for i, args, kwargs in _sighup_callbacks:
-                    i(*args, **kwargs)
+            sdnotify(b"READY=1")
 
-                sdnotify(b"READY=1")
+        # We defer running the sighup handlers until next reactor tick. This
+        # is so that we're in a sane state, e.g. flushing the logs may fail
+        # if the sighup happens in the middle of writing a log entry.
+        def run_sighup(*args, **kwargs):
+            # `callFromThread` should be "signal safe" as well as thread
+            # safe.
+            reactor.callFromThread(handle_sighup, *args, **kwargs)
 
-            # We defer running the sighup handlers until next reactor tick. This
-            # is so that we're in a sane state, e.g. flushing the logs may fail
-            # if the sighup happens in the middle of writing a log entry.
-            def run_sighup(*args, **kwargs):
-                # `callFromThread` should be "signal safe" as well as thread
-                # safe.
-                reactor.callFromThread(handle_sighup, *args, **kwargs)
+        signal.signal(signal.SIGHUP, run_sighup)
 
-            signal.signal(signal.SIGHUP, run_sighup)
+        register_sighup(refresh_certificate, hs)
 
-            register_sighup(refresh_certificate, hs)
+    # Load the certificate from disk.
+    refresh_certificate(hs)
 
-        # Load the certificate from disk.
-        refresh_certificate(hs)
+    # Start the tracer
+    synapse.logging.opentracing.init_tracer(  # type: ignore[attr-defined] # noqa
+        hs
+    )
 
-        # Start the tracer
-        synapse.logging.opentracing.init_tracer(  # type: ignore[attr-defined] # noqa
-            hs
-        )
+    # It is now safe to start your Synapse.
+    hs.start_listening(listeners)
+    hs.get_datastore().db_pool.start_profiling()
+    hs.get_pusherpool().start()
+
+    # Log when we start the shut down process.
+    hs.get_reactor().addSystemEventTrigger(
+        "before", "shutdown", logger.info, "Shutting down..."
+    )
 
-        # It is now safe to start your Synapse.
-        hs.start_listening(listeners)
-        hs.get_datastore().db_pool.start_profiling()
-        hs.get_pusherpool().start()
+    setup_sentry(hs)
+    setup_sdnotify(hs)
 
-        # Log when we start the shut down process.
-        hs.get_reactor().addSystemEventTrigger(
-            "before", "shutdown", logger.info, "Shutting down..."
-        )
+    # If background tasks are running on the main process, start collecting the
+    # phone home stats.
+    if hs.config.run_background_tasks:
+        start_phone_stats_home(hs)
 
-        setup_sentry(hs)
-        setup_sdnotify(hs)
-
-        # If background tasks are running on the main process, start collecting the
-        # phone home stats.
-        if hs.config.run_background_tasks:
-            start_phone_stats_home(hs)
-
-        # We now freeze all allocated objects in the hopes that (almost)
-        # everything currently allocated are things that will be used for the
-        # rest of time. Doing so means less work each GC (hopefully).
-        #
-        # This only works on Python 3.7
-        if sys.version_info >= (3, 7):
-            gc.collect()
-            gc.freeze()
-    except Exception:
-        traceback.print_exc(file=sys.stderr)
-        reactor = hs.get_reactor()
-        if reactor.running:
-            reactor.stop()
-        sys.exit(1)
+    # We now freeze all allocated objects in the hope that (almost) everything
+    # currently allocated is used for the rest of the process's lifetime.
+    # Doing so means less work on each GC run (hopefully).
+    #
+    # This only works on Python 3.7+, where gc.freeze() is available.
+    if sys.version_info >= (3, 7):
+        gc.collect()
+        gc.freeze()
 
 
 def setup_sentry(hs):
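`register_start` replaces the old pattern of hand-wiring `defer.ensureDeferred` into `reactor.callWhenRunning` at each call site (see the homeserver.py and generic_worker.py hunks below). A sketch of a call site, assuming `hs`, `config`, and the `_base` module are in scope, as they are inside `setup()` in homeserver.py:

```python
from synapse.app._base import register_start


async def start():
    # anything awaited here may raise; register_start's wrapper writes the
    # exception to both the logs and stderr, then calls os._exit(1)
    await hs.get_oidc_handler().load_metadata()
    await _base.start(hs, config.listeners)


register_start(start)
```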
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index fa23d9bb20..f24c648ac7 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -21,7 +21,7 @@ from typing import Dict, Iterable, Optional, Set
 
 from typing_extensions import ContextManager
 
-from twisted.internet import address, reactor
+from twisted.internet import address
 
 import synapse
 import synapse.events
@@ -34,6 +34,7 @@ from synapse.api.urls import (
     SERVER_KEY_V2_PREFIX,
 )
 from synapse.app import _base
+from synapse.app._base import register_start
 from synapse.config._base import ConfigError
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.logger import setup_logging
@@ -99,21 +100,27 @@ from synapse.rest.client.v1.profile import (
 )
 from synapse.rest.client.v1.push_rule import PushRuleRestServlet
 from synapse.rest.client.v1.voip import VoipRestServlet
-from synapse.rest.client.v2_alpha import groups, sync, user_directory
+from synapse.rest.client.v2_alpha import groups, room_keys, sync, user_directory
 from synapse.rest.client.v2_alpha._base import client_patterns
 from synapse.rest.client.v2_alpha.account import ThreepidRestServlet
 from synapse.rest.client.v2_alpha.account_data import (
     AccountDataServlet,
     RoomAccountDataServlet,
 )
-from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet
+from synapse.rest.client.v2_alpha.keys import (
+    KeyChangesServlet,
+    KeyQueryServlet,
+    OneTimeKeyServlet,
+)
 from synapse.rest.client.v2_alpha.register import RegisterRestServlet
+from synapse.rest.client.v2_alpha.sendtodevice import SendToDeviceRestServlet
 from synapse.rest.client.versions import VersionsRestServlet
 from synapse.rest.health import HealthResource
 from synapse.rest.key.v2 import KeyApiV2Resource
 from synapse.server import HomeServer, cache_in_self
 from synapse.storage.databases.main.censor_events import CensorEventsStore
 from synapse.storage.databases.main.client_ips import ClientIpWorkerStore
+from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyStore
 from synapse.storage.databases.main.media_repository import MediaRepositoryStore
 from synapse.storage.databases.main.metrics import ServerMetricsStore
 from synapse.storage.databases.main.monthly_active_users import (
@@ -445,6 +452,7 @@ class GenericWorkerSlavedStore(
     UserDirectoryStore,
     StatsStore,
     UIAuthWorkerStore,
+    EndToEndRoomKeyStore,
     SlavedDeviceInboxStore,
     SlavedDeviceStore,
     SlavedReceiptsStore,
@@ -502,6 +510,7 @@ class GenericWorkerServer(HomeServer):
                     LoginRestServlet(self).register(resource)
                     ThreepidRestServlet(self).register(resource)
                     KeyQueryServlet(self).register(resource)
+                    OneTimeKeyServlet(self).register(resource)
                     KeyChangesServlet(self).register(resource)
                     VoipRestServlet(self).register(resource)
                     PushRuleRestServlet(self).register(resource)
@@ -519,6 +528,9 @@ class GenericWorkerServer(HomeServer):
                     room.register_servlets(self, resource, True)
                     room.register_deprecated_servlets(self, resource)
                     InitialSyncRestServlet(self).register(resource)
+                    room_keys.register_servlets(self, resource)
+
+                    SendToDeviceRestServlet(self).register(resource)
 
                     user_directory.register_servlets(self, resource)
 
@@ -957,9 +969,7 @@ def start(config_options):
     # streams. Will no-op if no streams can be written to by this worker.
     hs.get_replication_streamer()
 
-    reactor.addSystemEventTrigger(
-        "before", "startup", _base.start, hs, config.worker_listeners
-    )
+    register_start(_base.start, hs, config.worker_listeners)
 
     _base.start_worker_reactor("synapse-generic-worker", config)
 
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 8d9b53be53..cbecf23be6 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -15,15 +15,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import gc
 import logging
 import os
 import sys
 from typing import Iterable, Iterator
 
-from twisted.application import service
-from twisted.internet import defer, reactor
-from twisted.python.failure import Failure
+from twisted.internet import reactor
 from twisted.web.resource import EncodingResourceWrapper, IResource
 from twisted.web.server import GzipEncoderFactory
 from twisted.web.static import File
@@ -40,7 +37,7 @@ from synapse.api.urls import (
     WEB_CLIENT_PREFIX,
 )
 from synapse.app import _base
-from synapse.app._base import listen_ssl, listen_tcp, quit_with_error
+from synapse.app._base import listen_ssl, listen_tcp, quit_with_error, register_start
 from synapse.config._base import ConfigError
 from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.config.homeserver import HomeServerConfig
@@ -63,6 +60,7 @@ from synapse.rest import ClientRestResource
 from synapse.rest.admin import AdminRestResource
 from synapse.rest.health import HealthResource
 from synapse.rest.key.v2 import KeyApiV2Resource
+from synapse.rest.synapse.client.pick_idp import PickIdpResource
 from synapse.rest.synapse.client.pick_username import pick_username_resource
 from synapse.rest.well_known import WellKnownResource
 from synapse.server import HomeServer
@@ -72,7 +70,6 @@ from synapse.storage.prepare_database import UpgradeDatabaseException
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.manhole import manhole
 from synapse.util.module_loader import load_module
-from synapse.util.rlimit import change_resource_limit
 from synapse.util.versionstring import get_version_string
 
 logger = logging.getLogger("synapse.app.homeserver")
@@ -194,6 +191,7 @@ class SynapseHomeServer(HomeServer):
                     "/.well-known/matrix/client": WellKnownResource(self),
                     "/_synapse/admin": AdminRestResource(self),
                     "/_synapse/client/pick_username": pick_username_resource(self),
+                    "/_synapse/client/pick_idp": PickIdpResource(self),
                 }
             )
 
@@ -415,40 +413,29 @@ def setup(config_options):
             _base.refresh_certificate(hs)
 
     async def start():
-        try:
-            # Run the ACME provisioning code, if it's enabled.
-            if hs.config.acme_enabled:
-                acme = hs.get_acme_handler()
-                # Start up the webservices which we will respond to ACME
-                # challenges with, and then provision.
-                await acme.start_listening()
-                await do_acme()
+        # Run the ACME provisioning code, if it's enabled.
+        if hs.config.acme_enabled:
+            acme = hs.get_acme_handler()
+            # Start up the webservices which we will respond to ACME
+            # challenges with, and then provision.
+            await acme.start_listening()
+            await do_acme()
 
-                # Check if it needs to be reprovisioned every day.
-                hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
+            # Check if it needs to be reprovisioned every day.
+            hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
 
-            # Load the OIDC provider metadatas, if OIDC is enabled.
-            if hs.config.oidc_enabled:
-                oidc = hs.get_oidc_handler()
-                # Loading the provider metadata also ensures the provider config is valid.
-                await oidc.load_metadata()
-                await oidc.load_jwks()
+        # Load the OIDC provider metadata, if OIDC is enabled.
+        if hs.config.oidc_enabled:
+            oidc = hs.get_oidc_handler()
+            # Loading the provider metadata also ensures the provider config is valid.
+            await oidc.load_metadata()
+            await oidc.load_jwks()
 
-            _base.start(hs, config.listeners)
+        await _base.start(hs, config.listeners)
 
-            hs.get_datastore().db_pool.updates.start_doing_background_updates()
-        except Exception:
-            # Print the exception and bail out.
-            print("Error during startup:", file=sys.stderr)
+        hs.get_datastore().db_pool.updates.start_doing_background_updates()
 
-            # this gives better tracebacks than traceback.print_exc()
-            Failure().printTraceback(file=sys.stderr)
-
-            if reactor.running:
-                reactor.stop()
-            sys.exit(1)
-
-    reactor.callWhenRunning(lambda: defer.ensureDeferred(start()))
+    register_start(start)
 
     return hs
 
@@ -485,25 +472,6 @@ def format_config_error(e: ConfigError) -> Iterator[str]:
         e = e.__cause__
 
 
-class SynapseService(service.Service):
-    """
-    A twisted Service class that will start synapse. Used to run synapse
-    via twistd and a .tac.
-    """
-
-    def __init__(self, config):
-        self.config = config
-
-    def startService(self):
-        hs = setup(self.config)
-        change_resource_limit(hs.config.soft_file_limit)
-        if hs.config.gc_thresholds:
-            gc.set_threshold(*hs.config.gc_thresholds)
-
-    def stopService(self):
-        return self._port.stopListening()
-
-
 def run(hs):
     PROFILE_SYNAPSE = False
     if PROFILE_SYNAPSE:
diff --git a/synapse/config/_util.py b/synapse/config/_util.py
index 1bbe83c317..8fce7f6bb1 100644
--- a/synapse/config/_util.py
+++ b/synapse/config/_util.py
@@ -56,7 +56,7 @@ def json_error_to_config_error(
     """
     # copy `config_path` before modifying it.
     path = list(config_path)
-    for p in list(e.path):
+    for p in list(e.absolute_path):
         if isinstance(p, int):
             path.append("<item %i>" % p)
         else:
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index 93bbd40937..1aeb1c5c92 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -31,6 +31,7 @@ class SSOConfig(Config):
 
         # Read templates from disk
         (
+            self.sso_login_idp_picker_template,
             self.sso_redirect_confirm_template,
             self.sso_auth_confirm_template,
             self.sso_error_template,
@@ -38,6 +39,7 @@ class SSOConfig(Config):
             sso_auth_success_template,
         ) = self.read_templates(
             [
+                "sso_login_idp_picker.html",
                 "sso_redirect_confirm.html",
                 "sso_auth_confirm.html",
                 "sso_error.html",
@@ -98,6 +100,31 @@ class SSOConfig(Config):
             #
             # Synapse will look for the following templates in this directory:
             #
+            # * HTML page to prompt the user to choose an Identity Provider during
+            #   login: 'sso_login_idp_picker.html'.
+            #
+            #   This is only used if multiple SSO Identity Providers are configured.
+            #
+            #   When rendering, this template is given the following variables:
+            #     * redirect_url: the URL that the user will be redirected to after
+            #       login. Needs manual escaping (see
+            #       https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
+            #
+            #     * server_name: the homeserver's name.
+            #
+            #     * providers: a list of available Identity Providers. Each element is
+            #       an object with the following attributes:
+            #         * idp_id: unique identifier for the IdP
+            #         * idp_name: user-facing name for the IdP
+            #
+            #   The rendered HTML page should contain a form which submits its results
+            #   back as a GET request, with the following query parameters:
+            #
+            #     * redirectUrl: the client redirect URI (ie, the `redirect_url` passed
+            #       to the template)
+            #
+            #     * idp: the 'idp_id' of the chosen IdP.
+            #
             # * HTML page for a confirmation step before redirecting back to the client
             #   with the login token: 'sso_redirect_confirm.html'.
             #
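A hypothetical, minimal `sso_login_idp_picker.html` built from the variables documented above, rendered with Jinja2. The form target is an assumption (the docs only specify the query parameters), and the template actually shipped in synapse/res/templates will differ:

```python
from jinja2 import Template

picker = Template("""
<form method="get" action="pick_idp">
  <input type="hidden" name="redirectUrl" value="{{ redirect_url | e }}">
  {% for p in providers %}
    <button type="submit" name="idp" value="{{ p.idp_id }}">{{ p.idp_name }}</button>
  {% endfor %}
</form>
""")

print(picker.render(
    redirect_url="https://client.example/?after=login",
    server_name="example.com",
    providers=[{"idp_id": "oidc", "idp_name": "My OIDC provider"}],
))
```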
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index 7ca9efec52..364583f48b 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -53,6 +53,9 @@ class WriterLocations:
         default=["master"], type=List[str], converter=_instance_to_list_converter
     )
     typing = attr.ib(default="master", type=str)
+    to_device = attr.ib(
+        default=["master"], type=List[str], converter=_instance_to_list_converter,
+    )
 
 
 class WorkerConfig(Config):
@@ -124,7 +127,7 @@ class WorkerConfig(Config):
 
         # Check that the configured writers for events and typing also appears in
         # `instance_map`.
-        for stream in ("events", "typing"):
+        for stream in ("events", "typing", "to_device"):
             instances = _instance_to_list_converter(getattr(self.writers, stream))
             for instance in instances:
                 if instance != "master" and instance not in self.instance_map:
@@ -133,6 +136,11 @@ class WorkerConfig(Config):
                         % (instance, stream)
                     )
 
+        if len(self.writers.to_device) != 1:
+            raise ConfigError(
+                "Must only specify one instance to handle `to_device` messages."
+            )
+
         self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)
 
         # Whether this worker should run background tasks or not.
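The new `to_device` entry reuses the same `_instance_to_list_converter` as `events`, which presumably normalises a bare instance name into a one-element list; a sketch of that assumed behaviour:

```python
from typing import List, Union


def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
    """Allow a config option that expects a list of instance names to
    also accept a single bare name (a sketch of the assumed behaviour)."""
    if isinstance(obj, str):
        return [obj]
    return obj


# both spellings are then accepted in homeserver.yaml:
assert _instance_to_list_converter("master") == ["master"]
assert _instance_to_list_converter(["worker1"]) == ["worker1"]
```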
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 14f7f1156f..9c22e33813 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -79,13 +79,15 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
         "state_key",
         "depth",
         "prev_events",
-        "prev_state",
         "auth_events",
         "origin",
         "origin_server_ts",
-        "membership",
     ]
 
+    # Room versions from before MSC2176 had additional allowed keys.
+    if not room_version.msc2176_redaction_rules:
+        allowed_keys.extend(["prev_state", "membership"])
+
     event_type = event_dict["type"]
 
     new_content = {}
@@ -98,6 +100,10 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
     if event_type == EventTypes.Member:
         add_fields("membership")
     elif event_type == EventTypes.Create:
+        # MSC2176 rules state that create events cannot be redacted.
+        if room_version.msc2176_redaction_rules:
+            return event_dict
+
         add_fields("creator")
     elif event_type == EventTypes.JoinRules:
         add_fields("join_rule")
@@ -112,10 +118,16 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
             "kick",
             "redact",
         )
+
+        if room_version.msc2176_redaction_rules:
+            add_fields("invite")
+
     elif event_type == EventTypes.Aliases and room_version.special_case_aliases_auth:
         add_fields("aliases")
     elif event_type == EventTypes.RoomHistoryVisibility:
         add_fields("history_visibility")
+    elif event_type == EventTypes.Redaction and room_version.msc2176_redaction_rules:
+        add_fields("redacts")
 
     allowed_fields = {k: v for k, v in event_dict.items() if k in allowed_keys}
 
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 35e345ce70..e5339aca23 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import random
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -860,8 +861,10 @@ class FederationHandlerRegistry:
         )  # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
         self.query_handlers = {}  # type: Dict[str, Callable[[dict], Awaitable[None]]]
 
-        # Map from type to instance name that we should route EDU handling to.
-        self._edu_type_to_instance = {}  # type: Dict[str, str]
+        # Map from type to instance names that we should route EDU handling to.
+        # We randomly choose one instance from the list to route to for each new
+        # EDU received.
+        self._edu_type_to_instance = {}  # type: Dict[str, List[str]]
 
     def register_edu_handler(
         self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]]
@@ -905,7 +908,12 @@ class FederationHandlerRegistry:
     def register_instance_for_edu(self, edu_type: str, instance_name: str):
         """Register that the EDU handler is on a different instance than master.
         """
-        self._edu_type_to_instance[edu_type] = instance_name
+        self._edu_type_to_instance[edu_type] = [instance_name]
+
+    def register_instances_for_edu(self, edu_type: str, instance_names: List[str]):
+        """Register that the EDU handler is on multiple instances.
+        """
+        self._edu_type_to_instance[edu_type] = instance_names
 
     async def on_edu(self, edu_type: str, origin: str, content: dict):
         if not self.config.use_presence and edu_type == "m.presence":
@@ -924,8 +932,11 @@ class FederationHandlerRegistry:
             return
 
         # Check if we can route it somewhere else that isn't us
-        route_to = self._edu_type_to_instance.get(edu_type, "master")
-        if route_to != self._instance_name:
+        instances = self._edu_type_to_instance.get(edu_type, ["master"])
+        if self._instance_name not in instances:
+            # Pick an instance randomly so that we don't overload one.
+            route_to = random.choice(instances)
+
             try:
                 await self._send_edu(
                     instance_name=route_to,
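The routing rule reduces to "handle locally if this instance is registered, otherwise pick a registered instance at random"; a toy illustration with invented names:

```python
import random

edu_type_to_instance = {"m.direct_to_device": ["worker1", "worker2"]}
my_instance = "master"

instances = edu_type_to_instance.get("m.direct_to_device", ["master"])
if my_instance in instances:
    print("handle the EDU locally")
else:
    # a random choice spreads load so no single writer takes all the traffic
    print("forward the EDU to", random.choice(instances))
```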
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index f4434673dc..4f881a439a 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -49,8 +49,13 @@ from synapse.api.errors import (
     UserDeactivatedError,
 )
 from synapse.api.ratelimiting import Ratelimiter
-from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS
+from synapse.handlers._base import BaseHandler
+from synapse.handlers.ui_auth import (
+    INTERACTIVE_AUTH_CHECKERS,
+    UIAuthSessionDataConstants,
+)
 from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
+from synapse.http import get_request_user_agent
 from synapse.http.server import finish_request, respond_with_html
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import defer_to_thread
@@ -62,8 +67,6 @@ from synapse.util.async_helpers import maybe_awaitable
 from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.threepids import canonicalise_email
 
-from ._base import BaseHandler
-
 if TYPE_CHECKING:
     from synapse.app.homeserver import HomeServer
 
@@ -284,7 +287,6 @@ class AuthHandler(BaseHandler):
         requester: Requester,
         request: SynapseRequest,
         request_body: Dict[str, Any],
-        clientip: str,
         description: str,
     ) -> Tuple[dict, Optional[str]]:
         """
@@ -301,8 +303,6 @@ class AuthHandler(BaseHandler):
 
             request_body: The body of the request sent by the client
 
-            clientip: The IP address of the client.
-
             description: A human readable string to be displayed to the user that
                          describes the operation happening on their account.
 
@@ -338,10 +338,10 @@ class AuthHandler(BaseHandler):
                 request_body.pop("auth", None)
                 return request_body, None
 
-        user_id = requester.user.to_string()
+        requester_user_id = requester.user.to_string()
 
         # Check if we should be ratelimited due to too many previous failed attempts
-        self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
+        self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False)
 
         # build a list of supported flows
         supported_ui_auth_types = await self._get_available_ui_auth_types(
@@ -349,13 +349,16 @@ class AuthHandler(BaseHandler):
         )
         flows = [[login_type] for login_type in supported_ui_auth_types]
 
+        def get_new_session_data() -> JsonDict:
+            return {UIAuthSessionDataConstants.REQUEST_USER_ID: requester_user_id}
+
         try:
             result, params, session_id = await self.check_ui_auth(
-                flows, request, request_body, clientip, description
+                flows, request, request_body, description, get_new_session_data,
             )
         except LoginError:
             # Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
-            self._failed_uia_attempts_ratelimiter.can_do_action(user_id)
+            self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id)
             raise
 
         # find the completed login type
@@ -363,14 +366,14 @@ class AuthHandler(BaseHandler):
             if login_type not in result:
                 continue
 
-            user_id = result[login_type]
+            validated_user_id = result[login_type]
             break
         else:
             # this can't happen
             raise Exception("check_auth returned True but no successful login type")
 
         # check that the UI auth matched the access token
-        if user_id != requester.user.to_string():
+        if validated_user_id != requester_user_id:
             raise AuthError(403, "Invalid auth")
 
         # Note that the access token has been validated.
@@ -402,13 +405,9 @@ class AuthHandler(BaseHandler):
 
         # if sso is enabled, allow the user to log in via SSO iff they have a mapping
         # from sso to mxid.
-        if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled:
-            if await self.store.get_external_ids_by_user(user.to_string()):
-                ui_auth_types.add(LoginType.SSO)
-
-        # Our CAS impl does not (yet) correctly register users in user_external_ids,
-        # so always offer that if it's available.
-        if self.hs.config.cas.cas_enabled:
+        if await self.hs.get_sso_handler().get_identity_providers_for_user(
+            user.to_string()
+        ):
             ui_auth_types.add(LoginType.SSO)
 
         return ui_auth_types
@@ -426,8 +425,8 @@ class AuthHandler(BaseHandler):
         flows: List[List[str]],
         request: SynapseRequest,
         clientdict: Dict[str, Any],
-        clientip: str,
         description: str,
+        get_new_session_data: Optional[Callable[[], JsonDict]] = None,
     ) -> Tuple[dict, dict, str]:
         """
         Takes a dictionary sent by the client in the login / registration
@@ -448,11 +447,16 @@ class AuthHandler(BaseHandler):
             clientdict: The dictionary from the client root level, not the
                         'auth' key: this method prompts for auth if none is sent.
 
-            clientip: The IP address of the client.
-
             description: A human readable string to be displayed to the user that
                          describes the operation happening on their account.
 
+            get_new_session_data:
+                an optional callback which will be called when starting a new session.
+                It should return data to be stored as part of the session.
+
+                The keys of the returned data should be entries in
+                UIAuthSessionDataConstants.
+
         Returns:
             A tuple of (creds, params, session_id).
 
@@ -480,10 +484,15 @@ class AuthHandler(BaseHandler):
 
         # If there's no session ID, create a new session.
         if not sid:
+            new_session_data = get_new_session_data() if get_new_session_data else {}
+
             session = await self.store.create_ui_auth_session(
                 clientdict, uri, method, description
             )
 
+            for k, v in new_session_data.items():
+                await self.set_session_data(session.session_id, k, v)
+
         else:
             try:
                 session = await self.store.get_ui_auth_session(sid)
@@ -539,7 +548,8 @@ class AuthHandler(BaseHandler):
             # authentication flow.
             await self.store.set_ui_auth_clientdict(sid, clientdict)
 
-        user_agent = request.get_user_agent("")
+        user_agent = get_request_user_agent(request)
+        clientip = request.getClientIP()
 
         await self.store.add_user_agent_ip_to_ui_auth_session(
             session.session_id, user_agent, clientip
@@ -644,7 +654,8 @@ class AuthHandler(BaseHandler):
 
         Args:
             session_id: The ID of this session as returned from check_auth
-            key: The key to store the data under
+            key: The key to store the data under. An entry from
+                UIAuthSessionDataConstants.
             value: The data to store
         """
         try:
@@ -660,7 +671,8 @@ class AuthHandler(BaseHandler):
 
         Args:
             session_id: The ID of this session as returned from check_auth
-            key: The key to store the data under
+            key: The key the data was stored under. An entry from
+                UIAuthSessionDataConstants.
             default: Value to return if the key has not been set
         """
         try:
@@ -1334,12 +1346,12 @@ class AuthHandler(BaseHandler):
         else:
             return False
 
-    async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
+    async def start_sso_ui_auth(self, request: SynapseRequest, session_id: str) -> str:
         """
         Get the HTML for the SSO redirect confirmation page.
 
         Args:
-            redirect_url: The URL to redirect to the SSO provider.
+            request: The incoming HTTP request
             session_id: The user interactive authentication session ID.
 
         Returns:
@@ -1349,6 +1361,35 @@ class AuthHandler(BaseHandler):
             session = await self.store.get_ui_auth_session(session_id)
         except StoreError:
             raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
+
+        user_id_to_verify = await self.get_session_data(
+            session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
+        )  # type: str
+
+        idps = await self.hs.get_sso_handler().get_identity_providers_for_user(
+            user_id_to_verify
+        )
+
+        if not idps:
+            # we checked that the user had some remote identities before offering an SSO
+            # flow, so either they have since been deleted or the client has requested
+            # SSO despite it not being offered.
+            raise SynapseError(400, "User has no SSO identities")
+
+        # for now, just pick one
+        idp_id, sso_auth_provider = next(iter(idps.items()))
+        if len(idps) > 1:
+            logger.warning(
+                "User %r has previously logged in with multiple SSO IdPs; arbitrarily "
+                "picking %r",
+                user_id_to_verify,
+                idp_id,
+            )
+
+        redirect_url = await sso_auth_provider.handle_redirect_request(
+            request, None, session_id
+        )
+
         return self._sso_auth_confirm_template.render(
             description=session.description, redirect_url=redirect_url,
         )
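With `clientip` gone from the signature, call sites shrink to the following shape; a sketch of a hypothetical servlet (the description string is whatever the endpoint wants shown to the user):

```python
# inside a servlet's on_POST, where self.auth_handler is the AuthHandler:
body, _ = await self.auth_handler.validate_user_via_ui_auth(
    requester,
    request,  # the handler now derives the client IP and user agent itself
    body,
    "modify your account settings",
)
```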
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index fca210a5a6..f3430c6713 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -75,10 +75,15 @@ class CasHandler:
         self._http_client = hs.get_proxied_http_client()
 
         # identifier for the external_ids table
-        self._auth_provider_id = "cas"
+        self.idp_id = "cas"
+
+        # user-facing name of this auth provider
+        self.idp_name = "CAS"
 
         self._sso_handler = hs.get_sso_handler()
 
+        self._sso_handler.register_identity_provider(self)
+
     def _build_service_param(self, args: Dict[str, str]) -> str:
         """
         Generates a value to use as the "service" parameter when redirecting or
@@ -105,7 +110,7 @@ class CasHandler:
         Args:
             ticket: The CAS ticket from the client.
             service_args: Additional arguments to include in the service URL.
-                Should be the same as those passed to `get_redirect_url`.
+                Should be the same as those passed to `handle_redirect_request`.
 
         Raises:
             CasError: If there's an error parsing the CAS response.
@@ -184,16 +189,31 @@ class CasHandler:
 
         return CasResponse(user, attributes)
 
-    def get_redirect_url(self, service_args: Dict[str, str]) -> str:
-        """
-        Generates a URL for the CAS server where the client should be redirected.
+    async def handle_redirect_request(
+        self,
+        request: SynapseRequest,
+        client_redirect_url: Optional[bytes],
+        ui_auth_session_id: Optional[str] = None,
+    ) -> str:
+        """Generates a URL for the CAS server where the client should be redirected.
 
         Args:
-            service_args: Additional arguments to include in the final redirect URL.
+            request: the incoming HTTP request
+            client_redirect_url: the URL that we should redirect the
+                client to after login (or None for UI Auth).
+            ui_auth_session_id: The session ID of the ongoing UI Auth (or
+                None if this is a login).
 
         Returns:
-            The URL to redirect the client to.
+            URL to redirect to
         """
+
+        if ui_auth_session_id:
+            service_args = {"session": ui_auth_session_id}
+        else:
+            assert client_redirect_url
+            service_args = {"redirectUrl": client_redirect_url.decode("utf8")}
+
         args = urllib.parse.urlencode(
             {"service": self._build_service_param(service_args)}
         )
@@ -275,7 +295,7 @@ class CasHandler:
         # first check if we're doing a UIA
         if session:
             return await self._sso_handler.complete_sso_ui_auth_request(
-                self._auth_provider_id, cas_response.username, session, request,
+                self.idp_id, cas_response.username, session, request,
             )
 
         # otherwise, we're handling a login request.
@@ -375,7 +395,7 @@ class CasHandler:
             return None
 
         await self._sso_handler.complete_sso_login_request(
-            self._auth_provider_id,
+            self.idp_id,
             cas_response.username,
             request,
             client_redirect_url,
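Taken together with the OIDC changes below, `register_identity_provider` implies a common identity-provider interface. The following is inferred from the two handlers in this patch rather than a class the patch defines; a sketch:

```python
from typing import Optional

from typing_extensions import Protocol

from synapse.http.site import SynapseRequest


class SsoIdentityProvider(Protocol):
    idp_id: str    # stable identifier, used as the auth provider id in external_ids
    idp_name: str  # user-facing name, e.g. for the IdP picker page

    async def handle_redirect_request(
        self,
        request: SynapseRequest,
        client_redirect_url: Optional[bytes],
        ui_auth_session_id: Optional[str] = None,
    ) -> str:
        """Return the URL to redirect the user's browser to."""
```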
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index e808142365..c4a3b26a84 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Optional
 
 from synapse.api.errors import SynapseError
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import UserID, create_requester
+from synapse.types import Requester, UserID, create_requester
 
 from ._base import BaseHandler
 
@@ -38,6 +38,7 @@ class DeactivateAccountHandler(BaseHandler):
         self._device_handler = hs.get_device_handler()
         self._room_member_handler = hs.get_room_member_handler()
         self._identity_handler = hs.get_identity_handler()
+        self._profile_handler = hs.get_profile_handler()
         self.user_directory_handler = hs.get_user_directory_handler()
         self._server_name = hs.hostname
 
@@ -52,16 +53,23 @@ class DeactivateAccountHandler(BaseHandler):
         self._account_validity_enabled = hs.config.account_validity.enabled
 
     async def deactivate_account(
-        self, user_id: str, erase_data: bool, id_server: Optional[str] = None
+        self,
+        user_id: str,
+        erase_data: bool,
+        requester: Requester,
+        id_server: Optional[str] = None,
+        by_admin: bool = False,
     ) -> bool:
         """Deactivate a user's account
 
         Args:
             user_id: ID of user to be deactivated
             erase_data: whether to GDPR-erase the user's data
+            requester: The user attempting to make this change.
             id_server: Use the given identity server when unbinding
                 any threepids. If None then will attempt to unbind using the
                 identity server specified when binding (if known).
+            by_admin: Whether this change was made by an administrator.
 
         Returns:
             True if identity server supports removing threepids, otherwise False.
@@ -121,6 +129,12 @@ class DeactivateAccountHandler(BaseHandler):
 
         # Mark the user as erased, if they asked for that
         if erase_data:
+            user = UserID.from_string(user_id)
+            # Remove avatar URL from this user
+            await self._profile_handler.set_avatar_url(user, requester, "", by_admin)
+            # Remove displayname from this user
+            await self._profile_handler.set_displayname(user, requester, "", by_admin)
+
             logger.info("Marking %s as erased", user_id)
             await self.store.mark_user_erased(user_id)
 
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 9cac5a8463..fc974a82e8 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -24,6 +24,7 @@ from synapse.logging.opentracing import (
     set_tag,
     start_active_span,
 )
+from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
 from synapse.types import JsonDict, UserID, get_domain_from_id
 from synapse.util import json_encoder
 from synapse.util.stringutils import random_string
@@ -44,13 +45,37 @@ class DeviceMessageHandler:
         self.store = hs.get_datastore()
         self.notifier = hs.get_notifier()
         self.is_mine = hs.is_mine
-        self.federation = hs.get_federation_sender()
 
-        hs.get_federation_registry().register_edu_handler(
-            "m.direct_to_device", self.on_direct_to_device_edu
-        )
+        # We only need to poke the federation sender explicitly if it's on the
+        # same instance. Other federation sender instances will get notified by
+        # `synapse.app.generic_worker.FederationSenderHandler` when it sees it
+        # in the to-device replication stream.
+        self.federation_sender = None
+        if hs.should_send_federation():
+            self.federation_sender = hs.get_federation_sender()
+
+        # If we can handle the to-device EDUs ourselves we do so; otherwise we
+        # route them to the appropriate worker.
+        if hs.get_instance_name() in hs.config.worker.writers.to_device:
+            hs.get_federation_registry().register_edu_handler(
+                "m.direct_to_device", self.on_direct_to_device_edu
+            )
+        else:
+            hs.get_federation_registry().register_instances_for_edu(
+                "m.direct_to_device", hs.config.worker.writers.to_device,
+            )
 
-        self._device_list_updater = hs.get_device_handler().device_list_updater
+        # The handler to call when we think a user's device list might be out of
+        # sync. We do all device list resyncing on the master instance, so if
+        # we're on a worker we hit the device resync replication API.
+        if hs.config.worker.worker_app is None:
+            self._user_device_resync = (
+                hs.get_device_handler().device_list_updater.user_device_resync
+            )
+        else:
+            self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client(
+                hs
+            )
 
     async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
         local_messages = {}
@@ -138,9 +163,7 @@ class DeviceMessageHandler:
             await self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
 
             # Immediately attempt a resync in the background
-            run_in_background(
-                self._device_list_updater.user_device_resync, sender_user_id
-            )
+            run_in_background(self._user_device_resync, user_id=sender_user_id)
 
     async def send_device_message(
         self,
@@ -195,7 +218,8 @@ class DeviceMessageHandler:
         )
 
         log_kv({"remote_messages": remote_messages})
-        for destination in remote_messages.keys():
-            # Enqueue a new federation transaction to send the new
-            # device messages to each remote destination.
-            self.federation.send_device_messages(destination)
+        if self.federation_sender:
+            for destination in remote_messages.keys():
+                # Enqueue a new federation transaction to send the new
+                # device messages to each remote destination.
+                self.federation_sender.send_device_messages(destination)
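Either branch above leaves `self._user_device_resync` as an async callable keyed on the user ID, so call sites stay uniform. One subtlety: the replication clients built by `make_client` accept their payload as keyword arguments, so passing the user ID by keyword (as in the resync hunk above) is the safe spelling for both the local method and the replication client:

```python
# inside DeviceMessageHandler, either implementation is invoked the same
# way, using the run_in_background already imported in this module
# (a sketch; the user ID is invented):
run_in_background(self._user_device_resync, user_id="@bob:remote.example")
```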
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 709f8dfc13..88097639ef 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import inspect
 import logging
-from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
+from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar
 from urllib.parse import urlencode
 
 import attr
@@ -35,7 +35,6 @@ from typing_extensions import TypedDict
 from twisted.web.client import readBody
 
 from synapse.config import ConfigError
-from synapse.handlers._base import BaseHandler
 from synapse.handlers.sso import MappingException, UserAttributes
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
@@ -85,12 +84,15 @@ class OidcError(Exception):
         return self.error
 
 
-class OidcHandler(BaseHandler):
+class OidcHandler:
     """Handles requests related to the OpenID Connect login flow.
     """
 
     def __init__(self, hs: "HomeServer"):
-        super().__init__(hs)
+        self._store = hs.get_datastore()
+
+        self._token_generator = OidcSessionTokenGenerator(hs)
+
         self._callback_url = hs.config.oidc_callback_url  # type: str
         self._scopes = hs.config.oidc_scopes  # type: List[str]
         self._user_profile_method = hs.config.oidc_user_profile_method  # type: str
@@ -116,13 +118,17 @@ class OidcHandler(BaseHandler):
 
         self._http_client = hs.get_proxied_http_client()
         self._server_name = hs.config.server_name  # type: str
-        self._macaroon_secret_key = hs.config.macaroon_secret_key
 
         # identifier for the external_ids table
-        self._auth_provider_id = "oidc"
+        self.idp_id = "oidc"
+
+        # user-facing name of this auth provider
+        self.idp_name = "OIDC"
 
         self._sso_handler = hs.get_sso_handler()
 
+        self._sso_handler.register_identity_provider(self)
+
     def _validate_metadata(self):
         """Verifies the provider metadata.
 
@@ -475,7 +481,7 @@ class OidcHandler(BaseHandler):
     async def handle_redirect_request(
         self,
         request: SynapseRequest,
-        client_redirect_url: bytes,
+        client_redirect_url: Optional[bytes],
         ui_auth_session_id: Optional[str] = None,
     ) -> str:
         """Handle an incoming request to /login/sso/redirect
@@ -499,7 +505,7 @@ class OidcHandler(BaseHandler):
             request: the incoming request from the browser.
                 We'll respond to it with a redirect and a cookie.
             client_redirect_url: the URL that we should redirect the client to
-                when everything is done
+                when everything is done (or None for UI Auth)
             ui_auth_session_id: The session ID of the ongoing UI Auth (or
                 None if this is a login).
 
@@ -511,11 +517,16 @@ class OidcHandler(BaseHandler):
         state = generate_token()
         nonce = generate_token()
 
-        cookie = self._generate_oidc_session_token(
+        if not client_redirect_url:
+            client_redirect_url = b""
+
+        cookie = self._token_generator.generate_oidc_session_token(
             state=state,
-            nonce=nonce,
-            client_redirect_url=client_redirect_url.decode(),
-            ui_auth_session_id=ui_auth_session_id,
+            session_data=OidcSessionData(
+                nonce=nonce,
+                client_redirect_url=client_redirect_url.decode(),
+                ui_auth_session_id=ui_auth_session_id,
+            ),
         )
         request.addCookie(
             SESSION_COOKIE_NAME,
@@ -620,11 +631,9 @@ class OidcHandler(BaseHandler):
 
         # Deserialize the session token and verify it.
         try:
-            (
-                nonce,
-                client_redirect_url,
-                ui_auth_session_id,
-            ) = self._verify_oidc_session_token(session, state)
+            session_data = self._token_generator.verify_oidc_session_token(
+                session, state
+            )
         except MacaroonDeserializationException as e:
             logger.exception("Invalid session")
             self._sso_handler.render_error(request, "invalid_session", str(e))
@@ -666,14 +675,14 @@ class OidcHandler(BaseHandler):
         else:
             logger.debug("Extracting userinfo from id_token")
             try:
-                userinfo = await self._parse_id_token(token, nonce=nonce)
+                userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
             except Exception as e:
                 logger.exception("Invalid id_token")
                 self._sso_handler.render_error(request, "invalid_token", str(e))
                 return
 
         # first check if we're doing a UIA
-        if ui_auth_session_id:
+        if session_data.ui_auth_session_id:
             try:
                 remote_user_id = self._remote_id_from_userinfo(userinfo)
             except Exception as e:
@@ -682,7 +691,7 @@ class OidcHandler(BaseHandler):
                 return
 
             return await self._sso_handler.complete_sso_ui_auth_request(
-                self._auth_provider_id, remote_user_id, ui_auth_session_id, request
+                self.idp_id, remote_user_id, session_data.ui_auth_session_id, request
             )
 
         # otherwise, it's a login
@@ -690,133 +699,12 @@ class OidcHandler(BaseHandler):
         # Call the mapper to register/login the user
         try:
             await self._complete_oidc_login(
-                userinfo, token, request, client_redirect_url
+                userinfo, token, request, session_data.client_redirect_url
             )
         except MappingException as e:
             logger.exception("Could not map user")
             self._sso_handler.render_error(request, "mapping_error", str(e))
 
-    def _generate_oidc_session_token(
-        self,
-        state: str,
-        nonce: str,
-        client_redirect_url: str,
-        ui_auth_session_id: Optional[str],
-        duration_in_ms: int = (60 * 60 * 1000),
-    ) -> str:
-        """Generates a signed token storing data about an OIDC session.
-
-        When Synapse initiates an authorization flow, it creates a random state
-        and a random nonce. Those parameters are given to the provider and
-        should be verified when the client comes back from the provider.
-        It is also used to store the client_redirect_url, which is used to
-        complete the SSO login flow.
-
-        Args:
-            state: The ``state`` parameter passed to the OIDC provider.
-            nonce: The ``nonce`` parameter passed to the OIDC provider.
-            client_redirect_url: The URL the client gave when it initiated the
-                flow.
-            ui_auth_session_id: The session ID of the ongoing UI Auth (or
-                None if this is a login).
-            duration_in_ms: An optional duration for the token in milliseconds.
-                Defaults to an hour.
-
-        Returns:
-            A signed macaroon token with the session information.
-        """
-        macaroon = pymacaroons.Macaroon(
-            location=self._server_name, identifier="key", key=self._macaroon_secret_key,
-        )
-        macaroon.add_first_party_caveat("gen = 1")
-        macaroon.add_first_party_caveat("type = session")
-        macaroon.add_first_party_caveat("state = %s" % (state,))
-        macaroon.add_first_party_caveat("nonce = %s" % (nonce,))
-        macaroon.add_first_party_caveat(
-            "client_redirect_url = %s" % (client_redirect_url,)
-        )
-        if ui_auth_session_id:
-            macaroon.add_first_party_caveat(
-                "ui_auth_session_id = %s" % (ui_auth_session_id,)
-            )
-        now = self.clock.time_msec()
-        expiry = now + duration_in_ms
-        macaroon.add_first_party_caveat("time < %d" % (expiry,))
-
-        return macaroon.serialize()
-
-    def _verify_oidc_session_token(
-        self, session: bytes, state: str
-    ) -> Tuple[str, str, Optional[str]]:
-        """Verifies and extract an OIDC session token.
-
-        This verifies that a given session token was issued by this homeserver
-        and extract the nonce and client_redirect_url caveats.
-
-        Args:
-            session: The session token to verify
-            state: The state the OIDC provider gave back
-
-        Returns:
-            The nonce, client_redirect_url, and ui_auth_session_id for this session
-        """
-        macaroon = pymacaroons.Macaroon.deserialize(session)
-
-        v = pymacaroons.Verifier()
-        v.satisfy_exact("gen = 1")
-        v.satisfy_exact("type = session")
-        v.satisfy_exact("state = %s" % (state,))
-        v.satisfy_general(lambda c: c.startswith("nonce = "))
-        v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
-        # Sometimes there's a UI auth session ID, it seems to be OK to attempt
-        # to always satisfy this.
-        v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
-        v.satisfy_general(self._verify_expiry)
-
-        v.verify(macaroon, self._macaroon_secret_key)
-
-        # Extract the `nonce`, `client_redirect_url`, and maybe the
-        # `ui_auth_session_id` from the token.
-        nonce = self._get_value_from_macaroon(macaroon, "nonce")
-        client_redirect_url = self._get_value_from_macaroon(
-            macaroon, "client_redirect_url"
-        )
-        try:
-            ui_auth_session_id = self._get_value_from_macaroon(
-                macaroon, "ui_auth_session_id"
-            )  # type: Optional[str]
-        except ValueError:
-            ui_auth_session_id = None
-
-        return nonce, client_redirect_url, ui_auth_session_id
-
-    def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
-        """Extracts a caveat value from a macaroon token.
-
-        Args:
-            macaroon: the token
-            key: the key of the caveat to extract
-
-        Returns:
-            The extracted value
-
-        Raises:
-            Exception: if the caveat was not in the macaroon
-        """
-        prefix = key + " = "
-        for caveat in macaroon.caveats:
-            if caveat.caveat_id.startswith(prefix):
-                return caveat.caveat_id[len(prefix) :]
-        raise ValueError("No %s caveat in macaroon" % (key,))
-
-    def _verify_expiry(self, caveat: str) -> bool:
-        prefix = "time < "
-        if not caveat.startswith(prefix):
-            return False
-        expiry = int(caveat[len(prefix) :])
-        now = self.clock.time_msec()
-        return now < expiry
-
     async def _complete_oidc_login(
         self,
         userinfo: UserInfo,
@@ -893,8 +781,8 @@ class OidcHandler(BaseHandler):
                 # and attempt to match it.
                 attributes = await oidc_response_to_user_attributes(failures=0)
 
-                user_id = UserID(attributes.localpart, self.server_name).to_string()
-                users = await self.store.get_users_by_id_case_insensitive(user_id)
+                user_id = UserID(attributes.localpart, self._server_name).to_string()
+                users = await self._store.get_users_by_id_case_insensitive(user_id)
                 if users:
                     # If an existing matrix ID is returned, then use it.
                     if len(users) == 1:
@@ -923,7 +811,7 @@ class OidcHandler(BaseHandler):
             extra_attributes = await get_extra_attributes(userinfo, token)
 
         await self._sso_handler.complete_sso_login_request(
-            self._auth_provider_id,
+            self.idp_id,
             remote_user_id,
             request,
             client_redirect_url,
@@ -946,6 +834,148 @@ class OidcHandler(BaseHandler):
         return str(remote_user_id)
 
 
+class OidcSessionTokenGenerator:
+    """Methods for generating and checking OIDC Session cookies."""
+
+    def __init__(self, hs: "HomeServer"):
+        self._clock = hs.get_clock()
+        self._server_name = hs.hostname
+        self._macaroon_secret_key = hs.config.key.macaroon_secret_key
+
+    def generate_oidc_session_token(
+        self,
+        state: str,
+        session_data: "OidcSessionData",
+        duration_in_ms: int = (60 * 60 * 1000),
+    ) -> str:
+        """Generates a signed token storing data about an OIDC session.
+
+        When Synapse initiates an authorization flow, it creates a random state
+        and a random nonce. Those parameters are given to the provider and
+        should be verified when the client comes back from the provider.
+        The token also stores the client_redirect_url, which is used to
+        complete the SSO login flow.
+
+        Args:
+            state: The ``state`` parameter passed to the OIDC provider.
+            session_data: data to include in the session token.
+            duration_in_ms: An optional duration for the token in milliseconds.
+                Defaults to an hour.
+
+        Returns:
+            A signed macaroon token with the session information.
+        """
+        macaroon = pymacaroons.Macaroon(
+            location=self._server_name, identifier="key", key=self._macaroon_secret_key,
+        )
+        macaroon.add_first_party_caveat("gen = 1")
+        macaroon.add_first_party_caveat("type = session")
+        macaroon.add_first_party_caveat("state = %s" % (state,))
+        macaroon.add_first_party_caveat("nonce = %s" % (session_data.nonce,))
+        macaroon.add_first_party_caveat(
+            "client_redirect_url = %s" % (session_data.client_redirect_url,)
+        )
+        if session_data.ui_auth_session_id:
+            macaroon.add_first_party_caveat(
+                "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
+            )
+        now = self._clock.time_msec()
+        expiry = now + duration_in_ms
+        macaroon.add_first_party_caveat("time < %d" % (expiry,))
+
+        return macaroon.serialize()
+
+    def verify_oidc_session_token(
+        self, session: bytes, state: str
+    ) -> "OidcSessionData":
+        """Verifies and extract an OIDC session token.
+
+        This verifies that a given session token was issued by this homeserver
+        and extract the nonce and client_redirect_url caveats.
+
+        Args:
+            session: The session token to verify
+            state: The state the OIDC provider gave back
+
+        Returns:
+            The data extracted from the session cookie
+        """
+        macaroon = pymacaroons.Macaroon.deserialize(session)
+
+        v = pymacaroons.Verifier()
+        v.satisfy_exact("gen = 1")
+        v.satisfy_exact("type = session")
+        v.satisfy_exact("state = %s" % (state,))
+        v.satisfy_general(lambda c: c.startswith("nonce = "))
+        v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
+        # Sometimes there's a UI auth session ID; it seems to be OK to always
+        # attempt to satisfy this.
+        v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
+        v.satisfy_general(self._verify_expiry)
+
+        v.verify(macaroon, self._macaroon_secret_key)
+
+        # Extract the `nonce`, `client_redirect_url`, and maybe the
+        # `ui_auth_session_id` from the token.
+        nonce = self._get_value_from_macaroon(macaroon, "nonce")
+        client_redirect_url = self._get_value_from_macaroon(
+            macaroon, "client_redirect_url"
+        )
+        try:
+            ui_auth_session_id = self._get_value_from_macaroon(
+                macaroon, "ui_auth_session_id"
+            )  # type: Optional[str]
+        except ValueError:
+            ui_auth_session_id = None
+
+        return OidcSessionData(
+            nonce=nonce,
+            client_redirect_url=client_redirect_url,
+            ui_auth_session_id=ui_auth_session_id,
+        )
+
+    def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
+        """Extracts a caveat value from a macaroon token.
+
+        Args:
+            macaroon: the token
+            key: the key of the caveat to extract
+
+        Returns:
+            The extracted value
+
+        Raises:
+            ValueError: if the caveat was not in the macaroon
+        """
+        prefix = key + " = "
+        for caveat in macaroon.caveats:
+            if caveat.caveat_id.startswith(prefix):
+                return caveat.caveat_id[len(prefix) :]
+        raise ValueError("No %s caveat in macaroon" % (key,))
+
+    def _verify_expiry(self, caveat: str) -> bool:
+        prefix = "time < "
+        if not caveat.startswith(prefix):
+            return False
+        expiry = int(caveat[len(prefix) :])
+        now = self._clock.time_msec()
+        return now < expiry
+
+
+@attr.s(frozen=True, slots=True)
+class OidcSessionData:
+    """The attributes which are stored in a OIDC session cookie"""
+
+    # The `nonce` parameter passed to the OIDC provider.
+    nonce = attr.ib(type=str)
+
+    # The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
+    client_redirect_url = attr.ib(type=str)
+
+    # The session ID of the ongoing UI Auth (None if this is a login)
+    ui_auth_session_id = attr.ib(type=Optional[str], default=None)
+
+
 UserAttributeDict = TypedDict(
     "UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]}
 )
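
For reference, the session cookie minted by OidcSessionTokenGenerator is an
ordinary macaroon built from first-party caveats. A minimal standalone sketch of
the mint/verify round trip, assuming a made-up secret and illustrative caveat
values (not taken from this changeset):

    import time

    import pymacaroons

    SECRET = "not-a-real-secret"  # stands in for hs.config.key.macaroon_secret_key

    # Mint: mirrors generate_oidc_session_token above.
    macaroon = pymacaroons.Macaroon(location="example.com", identifier="key", key=SECRET)
    macaroon.add_first_party_caveat("gen = 1")
    macaroon.add_first_party_caveat("type = session")
    macaroon.add_first_party_caveat("state = abc123")
    macaroon.add_first_party_caveat("nonce = n0nc3")
    macaroon.add_first_party_caveat("client_redirect_url = https://client.example")
    macaroon.add_first_party_caveat("time < %d" % (int(time.time() * 1000) + 3600000,))
    token = macaroon.serialize()

    # Verify: mirrors verify_oidc_session_token above.
    v = pymacaroons.Verifier()
    v.satisfy_exact("gen = 1")
    v.satisfy_exact("type = session")
    v.satisfy_exact("state = abc123")
    v.satisfy_general(lambda c: c.startswith("nonce = "))
    v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
    v.satisfy_general(lambda c: c.startswith("time < "))
    v.verify(pymacaroons.Macaroon.deserialize(token), SECRET)  # raises if invalid
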
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index dee0ef45e7..c02b951031 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -156,7 +156,7 @@ class ProfileHandler(BaseHandler):
             except HttpResponseException as e:
                 raise e.to_synapse_error()
 
-            return result["displayname"]
+            return result.get("displayname")
 
     async def set_displayname(
         self,
@@ -246,7 +246,7 @@ class ProfileHandler(BaseHandler):
             except HttpResponseException as e:
                 raise e.to_synapse_error()
 
-            return result["avatar_url"]
+            return result.get("avatar_url")
 
     async def set_avatar_url(
         self,
@@ -286,13 +286,19 @@ class ProfileHandler(BaseHandler):
                 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
             )
 
+        avatar_url_to_set = new_avatar_url  # type: Optional[str]
+        if new_avatar_url == "":
+            avatar_url_to_set = None
+
         # Same as in set_displayname
         if by_admin:
             requester = create_requester(
                 target_user, authenticated_entity=requester.authenticated_entity
             )
 
-        await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
+        await self.store.set_profile_avatar_url(
+            target_user.localpart, avatar_url_to_set
+        )
 
         if self.hs.config.user_directory_search_all_users:
             profile = await self.store.get_profileinfo(target_user.localpart)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 1f809fa161..3bece6d668 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -365,7 +365,7 @@ class RoomCreationHandler(BaseHandler):
         creation_content = {
             "room_version": new_room_version.identifier,
             "predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id},
-        }
+        }  # type: JsonDict
 
         # Check if old room was non-federatable
 
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 5fa7ab3f8b..a8376543c9 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -73,27 +73,41 @@ class SamlHandler(BaseHandler):
         )
 
         # identifier for the external_ids table
-        self._auth_provider_id = "saml"
+        self.idp_id = "saml"
+
+        # user-facing name of this auth provider
+        self.idp_name = "SAML"
 
         # a map from saml session id to Saml2SessionData object
         self._outstanding_requests_dict = {}  # type: Dict[str, Saml2SessionData]
 
         self._sso_handler = hs.get_sso_handler()
+        self._sso_handler.register_identity_provider(self)
 
-    def handle_redirect_request(
-        self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
-    ) -> bytes:
+    async def handle_redirect_request(
+        self,
+        request: SynapseRequest,
+        client_redirect_url: Optional[bytes],
+        ui_auth_session_id: Optional[str] = None,
+    ) -> str:
         """Handle an incoming request to /login/sso/redirect
 
         Args:
+            request: the incoming HTTP request
             client_redirect_url: the URL that we should redirect the
-                client to when everything is done
+                client to after login (or None for UI Auth).
             ui_auth_session_id: The session ID of the ongoing UI Auth (or
                 None if this is a login).
 
         Returns:
             URL to redirect to
         """
+        if not client_redirect_url:
+            # Some SAML identity providers (e.g. Google) require a
+            # RelayState parameter on requests, so pass in a dummy redirect URL
+            # (which will never get used).
+            client_redirect_url = b"unused"
+
         reqid, info = self._saml_client.prepare_for_authenticate(
             entityid=self._saml_idp_entityid, relay_state=client_redirect_url
         )
@@ -210,7 +224,7 @@ class SamlHandler(BaseHandler):
                 return
 
             return await self._sso_handler.complete_sso_ui_auth_request(
-                self._auth_provider_id,
+                self.idp_id,
                 remote_user_id,
                 current_session.ui_auth_session_id,
                 request,
@@ -306,7 +320,7 @@ class SamlHandler(BaseHandler):
             return None
 
         await self._sso_handler.complete_sso_login_request(
-            self._auth_provider_id,
+            self.idp_id,
             remote_user_id,
             request,
             client_redirect_url,
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 33cd6bc178..d096e0b091 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -12,15 +12,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import abc
 import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Mapping, Optional
+from urllib.parse import urlencode
 
 import attr
-from typing_extensions import NoReturn
+from typing_extensions import NoReturn, Protocol
 
 from twisted.web.http import Request
 
-from synapse.api.errors import RedirectException, SynapseError
+from synapse.api.errors import Codes, RedirectException, SynapseError
+from synapse.http import get_request_user_agent
 from synapse.http.server import respond_with_html
 from synapse.http.site import SynapseRequest
 from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
@@ -40,6 +43,58 @@ class MappingException(Exception):
     """
 
 
+class SsoIdentityProvider(Protocol):
+    """Abstract base class to be implemented by SSO Identity Providers
+
+    An Identity Provider, or IdP, is an external HTTP service which authenticates a
+    user, saying whether they should be allowed to log in or perform a given action.
+
+    Synapse supports various implementations of IdPs, including OpenID Connect, SAML,
+    and CAS.
+
+    The main entry point is `handle_redirect_request`, which should return a URI to
+    redirect the user's browser to the IdP's authentication page.
+
+    Each IdP should be registered with the SsoHandler via
+    `hs.get_sso_handler().register_identity_provider()`, so that requests to
+    `/_matrix/client/r0/login/sso/redirect` can be correctly dispatched.
+    """
+
+    @property
+    @abc.abstractmethod
+    def idp_id(self) -> str:
+        """A unique identifier for this SSO provider
+
+        Eg, "saml", "cas", "github"
+        """
+
+    @property
+    @abc.abstractmethod
+    def idp_name(self) -> str:
+        """User-facing name for this provider"""
+
+    @abc.abstractmethod
+    async def handle_redirect_request(
+        self,
+        request: SynapseRequest,
+        client_redirect_url: Optional[bytes],
+        ui_auth_session_id: Optional[str] = None,
+    ) -> str:
+        """Handle an incoming request to /login/sso/redirect
+
+        Args:
+            request: the incoming HTTP request
+            client_redirect_url: the URL that we should redirect the
+                client to after login (or None for UI Auth).
+            ui_auth_session_id: The session ID of the ongoing UI Auth (or
+                None if this is a login).
+
+        Returns:
+            URL to redirect to
+        """
+        raise NotImplementedError()
+
+
 @attr.s
 class UserAttributes:
     # the localpart of the mxid that the mapper has assigned to the user.
@@ -100,6 +155,49 @@ class SsoHandler:
         # a map from session id to session data
         self._username_mapping_sessions = {}  # type: Dict[str, UsernameMappingSession]
 
+        # map from idp_id to SsoIdentityProvider
+        self._identity_providers = {}  # type: Dict[str, SsoIdentityProvider]
+
+    def register_identity_provider(self, p: SsoIdentityProvider):
+        p_id = p.idp_id
+        assert p_id not in self._identity_providers
+        self._identity_providers[p_id] = p
+
+    def get_identity_providers(self) -> Mapping[str, SsoIdentityProvider]:
+        """Get the configured identity providers"""
+        return self._identity_providers
+
+    async def get_identity_providers_for_user(
+        self, user_id: str
+    ) -> Mapping[str, SsoIdentityProvider]:
+        """Get the SsoIdentityProviders which a user has used
+
+        Given a user id, get the identity providers that the user has previously
+        used to log in with (and thus could use to re-identify themselves for UI Auth).
+
+        Args:
+            user_id: MXID of user to look up
+
+        Returns:
+            a map of idp_id to SsoIdentityProvider
+        """
+        external_ids = await self._store.get_external_ids_by_user(user_id)
+
+        valid_idps = {}
+        for idp_id, _ in external_ids:
+            idp = self._identity_providers.get(idp_id)
+            if not idp:
+                logger.warning(
+                    "User %r has an SSO mapping for IdP %r, but this is no longer "
+                    "configured.",
+                    user_id,
+                    idp_id,
+                )
+            else:
+                valid_idps[idp_id] = idp
+
+        return valid_idps
+
     def render_error(
         self,
         request: Request,
@@ -124,6 +222,34 @@ class SsoHandler:
         )
         respond_with_html(request, code, html)
 
+    async def handle_redirect_request(
+        self, request: SynapseRequest, client_redirect_url: bytes,
+    ) -> str:
+        """Handle a request to /login/sso/redirect
+
+        Args:
+            request: incoming HTTP request
+            client_redirect_url: the URL that we should redirect the
+                client to after login.
+
+        Returns:
+            the URI to redirect to
+        """
+        if not self._identity_providers:
+            raise SynapseError(
+                400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED
+            )
+
+        # if we only have one auth provider, redirect to it directly
+        if len(self._identity_providers) == 1:
+            ap = next(iter(self._identity_providers.values()))
+            return await ap.handle_redirect_request(request, client_redirect_url)
+
+        # otherwise, redirect to the IDP picker. Decode the redirect URL first,
+        # since urlencode would otherwise stringify the raw bytes as "b'...'".
+        return "/_synapse/client/pick_idp?" + urlencode(
+            (("redirectUrl", client_redirect_url.decode("utf8", "replace")),)
+        )
+
     async def get_sso_user_by_remote_user_id(
         self, auth_provider_id: str, remote_user_id: str
     ) -> Optional[str]:
@@ -268,7 +394,7 @@ class SsoHandler:
                     attributes,
                     auth_provider_id,
                     remote_user_id,
-                    request.get_user_agent(""),
+                    get_request_user_agent(request),
                     request.getClientIP(),
                 )
 
@@ -534,7 +660,7 @@ class SsoHandler:
             attributes,
             session.auth_provider_id,
             session.remote_user_id,
-            request.get_user_agent(""),
+            get_request_user_agent(request),
             request.getClientIP(),
         )
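
The SsoIdentityProvider protocol above only requires idp_id, idp_name, and
handle_redirect_request. A hedged sketch of a provider registering itself with
the SsoHandler; DummyIdP and its redirect URL are hypothetical, not part of this
changeset:

    from typing import Optional

    from synapse.http.site import SynapseRequest

    class DummyIdP:
        """Hypothetical provider satisfying the SsoIdentityProvider protocol."""

        idp_id = "dummy"        # key used in the external_ids table
        idp_name = "Dummy IdP"  # user-facing name shown in the IdP picker

        async def handle_redirect_request(
            self,
            request: SynapseRequest,
            client_redirect_url: Optional[bytes],
            ui_auth_session_id: Optional[str] = None,
        ) -> str:
            # A real provider would build its authentication URL here, carrying
            # client_redirect_url / ui_auth_session_id through its own state.
            return "https://idp.example.com/auth"

    # Registration, given a HomeServer instance `hs`:
    #     hs.get_sso_handler().register_identity_provider(DummyIdP())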
 
diff --git a/synapse/handlers/ui_auth/__init__.py b/synapse/handlers/ui_auth/__init__.py
index 824f37f8f8..a68d5e790e 100644
--- a/synapse/handlers/ui_auth/__init__.py
+++ b/synapse/handlers/ui_auth/__init__.py
@@ -20,3 +20,18 @@ TODO: move more stuff out of AuthHandler in here.
 """
 
 from synapse.handlers.ui_auth.checkers import INTERACTIVE_AUTH_CHECKERS  # noqa: F401
+
+
+class UIAuthSessionDataConstants:
+    """Constants for use with AuthHandler.set_session_data"""
+
+    # used during registration and password reset to store a hashed copy of the
+    # password, so that the client does not need to submit it each time.
+    PASSWORD_HASH = "password_hash"
+
+    # used during registration to store the mxid of the registered user
+    REGISTERED_USER_ID = "registered_user_id"
+
+    # used by validate_user_via_ui_auth to store the mxid of the user we are validating
+    # for.
+    REQUEST_USER_ID = "request_user_id"
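
These constants replace the bare strings previously passed to
AuthHandler.set_session_data (see the account.py changes later in this diff). A
hedged usage sketch, where auth_handler, session_id, and password_hash are
assumed to come from a surrounding UI-auth flow:

    from synapse.handlers.ui_auth import UIAuthSessionDataConstants

    async def stash_password_hash(auth_handler, session_id: str, password_hash: str):
        # Store the hash against the UI-auth session...
        await auth_handler.set_session_data(
            session_id, UIAuthSessionDataConstants.PASSWORD_HASH, password_hash
        )
        # ...and read it back later, with None as the default.
        return await auth_handler.get_session_data(
            session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
        )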
diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index 59b01b812c..4bc3cb53f0 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -17,6 +17,7 @@ import re
 
 from twisted.internet import task
 from twisted.web.client import FileBodyProducer
+from twisted.web.iweb import IRequest
 
 from synapse.api.errors import SynapseError
 
@@ -50,3 +51,17 @@ class QuieterFileBodyProducer(FileBodyProducer):
             FileBodyProducer.stopProducing(self)
         except task.TaskStopped:
             pass
+
+
+def get_request_user_agent(request: IRequest, default: str = "") -> str:
+    """Return the last User-Agent header, or the given default.
+    """
+    # There could be raw utf-8 bytes in the User-Agent header.
+    #
+    # N.B. if we don't decode with errors="replace", the logger explodes
+    # cryptically with maximum recursion trying to log errors about
+    # the charset problem.
+    # c.f. https://github.com/matrix-org/synapse/issues/3471
+
+    h = request.getHeader(b"User-Agent")
+    return h.decode("ascii", "replace") if h else default
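
A quick sketch of the new helper's behaviour with a stub request; FakeRequest is
hypothetical, standing in for a Twisted IRequest:

    from synapse.http import get_request_user_agent

    class FakeRequest:
        def __init__(self, ua):
            self._ua = ua

        def getHeader(self, name: bytes):
            return self._ua

    # Invalid-ascii bytes are replaced rather than crashing the logger.
    assert get_request_user_agent(FakeRequest(b"curl/7.68 \xff")) == "curl/7.68 \ufffd"
    assert get_request_user_agent(FakeRequest(None), "-") == "-"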
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index b261e078c4..b7103d6541 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -174,6 +174,16 @@ async def _handle_json_response(
         d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
 
         body = await make_deferred_yieldable(d)
+    except ValueError as e:
+        # The JSON content was invalid.
+        logger.warning(
+            "{%s} [%s] Failed to parse JSON response - %s %s",
+            request.txn_id,
+            request.destination,
+            request.method,
+            request.uri.decode("ascii"),
+        )
+        raise RequestSendFailed(e, can_retry=False) from e
     except defer.TimeoutError as e:
         logger.warning(
             "{%s} [%s] Timed out reading response - %s %s",
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 5a5790831b..12ec3f851f 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -20,7 +20,7 @@ from twisted.python.failure import Failure
 from twisted.web.server import Request, Site
 
 from synapse.config.server import ListenerConfig
-from synapse.http import redact_uri
+from synapse.http import get_request_user_agent, redact_uri
 from synapse.http.request_metrics import RequestMetrics, requests_counter
 from synapse.logging.context import LoggingContext, PreserveLoggingContext
 from synapse.types import Requester
@@ -113,15 +113,6 @@ class SynapseRequest(Request):
             method = self.method.decode("ascii")
         return method
 
-    def get_user_agent(self, default: str) -> str:
-        """Return the last User-Agent header, or the given default.
-        """
-        user_agent = self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1]
-        if user_agent is None:
-            return default
-
-        return user_agent.decode("ascii", "replace")
-
     def render(self, resrc):
         # this is called once a Resource has been found to serve the request; in our
         # case the Resource in question will normally be a JsonResource.
@@ -292,12 +283,7 @@ class SynapseRequest(Request):
             # and can see that we're doing something wrong.
             authenticated_entity = repr(self.requester)  # type: ignore[unreachable]
 
-        # ...or could be raw utf-8 bytes in the User-Agent header.
-        # N.B. if you don't do this, the logger explodes cryptically
-        # with maximum recursion trying to log errors about
-        # the charset problem.
-        # c.f. https://github.com/matrix-org/synapse/issues/3471
-        user_agent = self.get_user_agent("-")
+        user_agent = get_request_user_agent(self, "-")
 
         code = str(self.code)
         if not self.finished:
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index a507a83e93..c2db8b45f3 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -252,7 +252,12 @@ class LoggingContext:
         "scope",
     ]
 
-    def __init__(self, name=None, parent_context=None, request=None) -> None:
+    def __init__(
+        self,
+        name: Optional[str] = None,
+        parent_context: "Optional[LoggingContext]" = None,
+        request: Optional[str] = None,
+    ) -> None:
         self.previous_context = current_context()
         self.name = name
 
@@ -536,20 +541,20 @@ class LoggingContextFilter(logging.Filter):
     def __init__(self, request: str = ""):
         self._default_request = request
 
-    def filter(self, record) -> Literal[True]:
+    def filter(self, record: logging.LogRecord) -> Literal[True]:
         """Add each fields from the logging contexts to the record.
         Returns:
             True to include the record in the log output.
         """
         context = current_context()
-        record.request = self._default_request
+        record.request = self._default_request  # type: ignore
 
         # context should never be None, but if it somehow ends up being, then
         # we end up in a death spiral of infinite loops, so let's check, for
         # robustness' sake.
         if context is not None:
             # Logging is interested in the request.
-            record.request = context.request
+            record.request = context.request  # type: ignore
 
         return True
 
@@ -616,9 +621,7 @@ def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSe
     return current
 
 
-def nested_logging_context(
-    suffix: str, parent_context: Optional[LoggingContext] = None
-) -> LoggingContext:
+def nested_logging_context(suffix: str) -> LoggingContext:
     """Creates a new logging context as a child of another.
 
     The nested logging context will have a 'request' made up of the parent context's
@@ -632,20 +635,23 @@ def nested_logging_context(
             # ... do stuff
 
     Args:
-        suffix (str): suffix to add to the parent context's 'request'.
-        parent_context (LoggingContext|None): parent context. Will use the current context
-            if None.
+        suffix: suffix to add to the parent context's 'request'.
 
     Returns:
         LoggingContext: new logging context.
     """
-    if parent_context is not None:
-        context = parent_context  # type: LoggingContextOrSentinel
+    curr_context = current_context()
+    if not curr_context:
+        logger.warning(
+            "Starting nested logging context from sentinel context: metrics will be lost"
+        )
+        parent_context = None
+        prefix = ""
     else:
-        context = current_context()
-    return LoggingContext(
-        parent_context=context, request=str(context.request) + "-" + suffix
-    )
+        assert isinstance(curr_context, LoggingContext)
+        parent_context = curr_context
+        prefix = str(parent_context.request)
+    return LoggingContext(parent_context=parent_context, request=prefix + "-" + suffix)
 
 
 def preserve_fn(f):
@@ -822,10 +828,18 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
         Deferred: A Deferred which fires a callback with the result of `f`, or an
             errback if `f` throws an exception.
     """
-    logcontext = current_context()
+    curr_context = current_context()
+    if not curr_context:
+        logger.warning(
+            "Calling defer_to_threadpool from sentinel context: metrics will be lost"
+        )
+        parent_context = None
+    else:
+        assert isinstance(curr_context, LoggingContext)
+        parent_context = curr_context
 
     def g():
-        with LoggingContext(parent_context=logcontext):
+        with LoggingContext(parent_context=parent_context):
             return f(*args, **kwargs)
 
     return make_deferred_yieldable(threads.deferToThreadPool(reactor, threadpool, g))
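
With the parent_context parameter gone, nested_logging_context always derives
the child from the current context. A minimal usage sketch (the request names
are illustrative):

    from synapse.logging.context import LoggingContext, nested_logging_context

    def do_work():
        with LoggingContext(name="main", request="GET-42"):
            # The nested context's request becomes "GET-42-retry-1".
            with nested_logging_context("retry-1"):
                pass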
diff --git a/synapse/notifier.py b/synapse/notifier.py
index c4c8bb271d..0745899b48 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -396,31 +396,30 @@ class Notifier:
 
         Will wake up all listeners for the given users and rooms.
         """
-        with PreserveLoggingContext():
-            with Measure(self.clock, "on_new_event"):
-                user_streams = set()
+        with Measure(self.clock, "on_new_event"):
+            user_streams = set()
 
-                for user in users:
-                    user_stream = self.user_to_user_stream.get(str(user))
-                    if user_stream is not None:
-                        user_streams.add(user_stream)
+            for user in users:
+                user_stream = self.user_to_user_stream.get(str(user))
+                if user_stream is not None:
+                    user_streams.add(user_stream)
 
-                for room in rooms:
-                    user_streams |= self.room_to_user_streams.get(room, set())
+            for room in rooms:
+                user_streams |= self.room_to_user_streams.get(room, set())
 
-                time_now_ms = self.clock.time_msec()
-                for user_stream in user_streams:
-                    try:
-                        user_stream.notify(stream_key, new_token, time_now_ms)
-                    except Exception:
-                        logger.exception("Failed to notify listener")
+            time_now_ms = self.clock.time_msec()
+            for user_stream in user_streams:
+                try:
+                    user_stream.notify(stream_key, new_token, time_now_ms)
+                except Exception:
+                    logger.exception("Failed to notify listener")
 
-                self.notify_replication()
+            self.notify_replication()
 
-                # Notify appservices
-                self._notify_app_services_ephemeral(
-                    stream_key, new_token, users,
-                )
+            # Notify appservices
+            self._notify_app_services_ephemeral(
+                stream_key, new_token, users,
+            )
 
     def on_new_replication_data(self) -> None:
         """Used to inform replication listeners that something has happened
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 10f27e4378..9018f9e20b 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -203,14 +203,18 @@ class BulkPushRuleEvaluator:
 
         condition_cache = {}  # type: Dict[str, bool]
 
+        # If the event is not a state event, check if any users ignore the sender.
+        if not event.is_state():
+            ignorers = await self.store.ignored_by(event.sender)
+        else:
+            ignorers = set()
+
         for uid, rules in rules_by_user.items():
             if event.sender == uid:
                 continue
 
-            if not event.is_state():
-                is_ignored = await self.store.is_ignored_by(event.sender, uid)
-                if is_ignored:
-                    continue
+            if uid in ignorers:
+                continue
 
             display_name = None
             profile_info = room_members.get(uid)
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 5b045bed02..1260f6d141 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -14,46 +14,8 @@
 # limitations under the License.
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.replication.tcp.streams import ToDeviceStream
-from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
-from synapse.util.caches.expiringcache import ExpiringCache
-from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 
 class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        super().__init__(database, db_conn, hs)
-        self._device_inbox_id_gen = SlavedIdTracker(
-            db_conn, "device_inbox", "stream_id"
-        )
-        self._device_inbox_stream_cache = StreamChangeCache(
-            "DeviceInboxStreamChangeCache",
-            self._device_inbox_id_gen.get_current_token(),
-        )
-        self._device_federation_outbox_stream_cache = StreamChangeCache(
-            "DeviceFederationOutboxStreamChangeCache",
-            self._device_inbox_id_gen.get_current_token(),
-        )
-
-        self._last_device_delete_cache = ExpiringCache(
-            cache_name="last_device_delete_cache",
-            clock=self._clock,
-            max_len=10000,
-            expiry_ms=30 * 60 * 1000,
-        )
-
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == ToDeviceStream.NAME:
-            self._device_inbox_id_gen.advance(instance_name, token)
-            for row in rows:
-                if row.entity.startswith("@"):
-                    self._device_inbox_stream_cache.entity_has_changed(
-                        row.entity, token
-                    )
-                else:
-                    self._device_federation_outbox_stream_cache.entity_has_changed(
-                        row.entity, token
-                    )
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
+    pass
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 95e5502bf2..1f89249475 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -56,6 +56,7 @@ from synapse.replication.tcp.streams import (
     EventsStream,
     FederationStream,
     Stream,
+    ToDeviceStream,
     TypingStream,
 )
 
@@ -115,6 +116,14 @@ class ReplicationCommandHandler:
 
                 continue
 
+            if isinstance(stream, ToDeviceStream):
+                # Only add ToDeviceStream as a source on instances in charge of
+                # sending to-device messages.
+                if hs.get_instance_name() in hs.config.worker.writers.to_device:
+                    self._streams_to_replicate.append(stream)
+
+                continue
+
             if isinstance(stream, TypingStream):
                 # Only add TypingStream as a source on the instance in charge of
                 # typing.
diff --git a/synapse/res/templates/sso_login_idp_picker.html b/synapse/res/templates/sso_login_idp_picker.html
new file mode 100644
index 0000000000..f53c9cd679
--- /dev/null
+++ b/synapse/res/templates/sso_login_idp_picker.html
@@ -0,0 +1,28 @@
+<!DOCTYPE html>
+<html lang="en">
+    <head>
+        <meta charset="UTF-8">
+        <link rel="stylesheet" href="/_matrix/static/client/login/style.css">
+        <title>{{server_name | e}} Login</title>
+    </head>
+    <body>
+        <div id="container">
+            <h1 id="title">{{server_name | e}} Login</h1>
+            <div class="login_flow">
+                <p>Choose one of the following identity providers:</p>
+            <form>
+                <input type="hidden" name="redirectUrl" value="{{redirect_url | e}}">
+                <ul class="radiobuttons">
+{% for p in providers %}
+                    <li>
+                        <input type="radio" name="idp" id="prov{{loop.index}}" value="{{p.idp_id}}">
+                        <label for="prov{{loop.index}}">{{p.idp_name | e}}</label>
+                    </li>
+{% endfor %}
+                </ul>
+                <input type="submit" class="button button--full-width" id="button-submit" value="Submit">
+            </form>
+            </div>
+        </div>
+    </body>
+</html>
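
The template only relies on server_name, redirect_url, and a providers iterable
whose items expose idp_id and idp_name. A hedged sketch of rendering it
standalone with Jinja2; the loader path and values are illustrative:

    from collections import namedtuple

    from jinja2 import Environment, FileSystemLoader

    FakeIdP = namedtuple("FakeIdP", ["idp_id", "idp_name"])

    env = Environment(loader=FileSystemLoader("synapse/res/templates"))
    html = env.get_template("sso_login_idp_picker.html").render(
        server_name="example.com",
        redirect_url="https://client.example/done",
        providers=[FakeIdP("oidc", "OIDC"), FakeIdP("saml", "SAML")],
    )
    print(html)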
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 6658c2da56..f39e3d6d5c 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -244,7 +244,7 @@ class UserRestServletV2(RestServlet):
 
                 if deactivate and not user["deactivated"]:
                     await self.deactivate_account_handler.deactivate_account(
-                        target_user.to_string(), False
+                        target_user.to_string(), False, requester, by_admin=True
                     )
                 elif not deactivate and user["deactivated"]:
                     if "password" not in body:
@@ -486,12 +486,22 @@ class WhoisRestServlet(RestServlet):
 class DeactivateAccountRestServlet(RestServlet):
     PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self._deactivate_account_handler = hs.get_deactivate_account_handler()
         self.auth = hs.get_auth()
+        self.is_mine = hs.is_mine
+        self.store = hs.get_datastore()
+
+    async def on_POST(
+        self, request: SynapseRequest, target_user_id: str
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
+
+        if not self.is_mine(UserID.from_string(target_user_id)):
+            raise SynapseError(400, "Can only deactivate local users")
+
+        if not await self.store.get_user_by_id(target_user_id):
+            raise NotFoundError("User not found")
 
-    async def on_POST(self, request, target_user_id):
-        await assert_requester_is_admin(self.auth, request)
         body = parse_json_object_from_request(request, allow_empty_body=True)
         erase = body.get("erase", False)
         if not isinstance(erase, bool):
@@ -501,10 +511,8 @@ class DeactivateAccountRestServlet(RestServlet):
                 Codes.BAD_JSON,
             )
 
-        UserID.from_string(target_user_id)
-
         result = await self._deactivate_account_handler.deactivate_account(
-            target_user_id, erase
+            target_user_id, erase, requester, by_admin=True
         )
         if result:
             id_server_unbind_result = "success"
@@ -714,13 +722,6 @@ class UserMembershipRestServlet(RestServlet):
     async def on_GET(self, request, user_id):
         await assert_requester_is_admin(self.auth, request)
 
-        if not self.is_mine(UserID.from_string(user_id)):
-            raise SynapseError(400, "Can only lookup local users")
-
-        user = await self.store.get_user_by_id(user_id)
-        if user is None:
-            raise NotFoundError("Unknown user")
-
         room_ids = await self.store.get_rooms_for_user(user_id)
         ret = {"joined_rooms": list(room_ids), "total": len(room_ids)}
         return 200, ret
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 5f4c6703db..be938df962 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -311,48 +311,31 @@ class LoginRestServlet(RestServlet):
         return result
 
 
-class BaseSSORedirectServlet(RestServlet):
-    """Common base class for /login/sso/redirect impls"""
-
+class SsoRedirectServlet(RestServlet):
     PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
 
+    def __init__(self, hs: "HomeServer"):
+        # make sure that the relevant handlers are instantiated, so that they
+        # register themselves with the main SSOHandler.
+        if hs.config.cas_enabled:
+            hs.get_cas_handler()
+        if hs.config.saml2_enabled:
+            hs.get_saml_handler()
+        if hs.config.oidc_enabled:
+            hs.get_oidc_handler()
+        self._sso_handler = hs.get_sso_handler()
+
     async def on_GET(self, request: SynapseRequest):
-        args = request.args
-        if b"redirectUrl" not in args:
-            return 400, "Redirect URL not specified for SSO auth"
-        client_redirect_url = args[b"redirectUrl"][0]
-        sso_url = await self.get_sso_url(request, client_redirect_url)
+        client_redirect_url = parse_string(
+            request, "redirectUrl", required=True, encoding=None
+        )
+        sso_url = await self._sso_handler.handle_redirect_request(
+            request, client_redirect_url
+        )
+        logger.info("Redirecting to %s", sso_url)
         request.redirect(sso_url)
         finish_request(request)
 
-    async def get_sso_url(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> bytes:
-        """Get the URL to redirect to, to perform SSO auth
-
-        Args:
-            request: The client request to redirect.
-            client_redirect_url: the URL that we should redirect the
-                client to when everything is done
-
-        Returns:
-            URL to redirect to
-        """
-        # to be implemented by subclasses
-        raise NotImplementedError()
-
-
-class CasRedirectServlet(BaseSSORedirectServlet):
-    def __init__(self, hs):
-        self._cas_handler = hs.get_cas_handler()
-
-    async def get_sso_url(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> bytes:
-        return self._cas_handler.get_redirect_url(
-            {"redirectUrl": client_redirect_url}
-        ).encode("ascii")
-
 
 class CasTicketServlet(RestServlet):
     PATTERNS = client_patterns("/login/cas/ticket", v1=True)
@@ -379,40 +362,8 @@ class CasTicketServlet(RestServlet):
         )
 
 
-class SAMLRedirectServlet(BaseSSORedirectServlet):
-    PATTERNS = client_patterns("/login/sso/redirect", v1=True)
-
-    def __init__(self, hs):
-        self._saml_handler = hs.get_saml_handler()
-
-    async def get_sso_url(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> bytes:
-        return self._saml_handler.handle_redirect_request(client_redirect_url)
-
-
-class OIDCRedirectServlet(BaseSSORedirectServlet):
-    """Implementation for /login/sso/redirect for the OIDC login flow."""
-
-    PATTERNS = client_patterns("/login/sso/redirect", v1=True)
-
-    def __init__(self, hs):
-        self._oidc_handler = hs.get_oidc_handler()
-
-    async def get_sso_url(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> bytes:
-        return await self._oidc_handler.handle_redirect_request(
-            request, client_redirect_url
-        )
-
-
 def register_servlets(hs, http_server):
     LoginRestServlet(hs).register(http_server)
+    SsoRedirectServlet(hs).register(http_server)
     if hs.config.cas_enabled:
-        CasRedirectServlet(hs).register(http_server)
         CasTicketServlet(hs).register(http_server)
-    elif hs.config.saml2_enabled:
-        SAMLRedirectServlet(hs).register(http_server)
-    elif hs.config.oidc_enabled:
-        OIDCRedirectServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index d837bde1d6..65e68d641b 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -20,9 +20,6 @@ from http import HTTPStatus
 from typing import TYPE_CHECKING
 from urllib.parse import urlparse
 
-if TYPE_CHECKING:
-    from synapse.app.homeserver import HomeServer
-
 from synapse.api.constants import LoginType
 from synapse.api.errors import (
     Codes,
@@ -31,6 +28,7 @@ from synapse.api.errors import (
     ThreepidValidationError,
 )
 from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http.server import finish_request, respond_with_html
 from synapse.http.servlet import (
     RestServlet,
@@ -46,6 +44,10 @@ from synapse.util.threepids import canonicalise_email, check_3pid_allowed
 
 from ._base import client_patterns, interactive_auth_handler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
+
 logger = logging.getLogger(__name__)
 
 
@@ -189,11 +191,7 @@ class PasswordRestServlet(RestServlet):
             requester = await self.auth.get_user_by_req(request)
             try:
                 params, session_id = await self.auth_handler.validate_user_via_ui_auth(
-                    requester,
-                    request,
-                    body,
-                    self.hs.get_ip_from_request(request),
-                    "modify your account password",
+                    requester, request, body, "modify your account password",
                 )
             except InteractiveAuthIncompleteError as e:
                 # The user needs to provide more steps to complete auth, but
@@ -204,7 +202,9 @@ class PasswordRestServlet(RestServlet):
                 if new_password:
                     password_hash = await self.auth_handler.hash(new_password)
                     await self.auth_handler.set_session_data(
-                        e.session_id, "password_hash", password_hash
+                        e.session_id,
+                        UIAuthSessionDataConstants.PASSWORD_HASH,
+                        password_hash,
                     )
                 raise
             user_id = requester.user.to_string()
@@ -215,7 +215,6 @@ class PasswordRestServlet(RestServlet):
                     [[LoginType.EMAIL_IDENTITY]],
                     request,
                     body,
-                    self.hs.get_ip_from_request(request),
                     "modify your account password",
                 )
             except InteractiveAuthIncompleteError as e:
@@ -227,7 +226,9 @@ class PasswordRestServlet(RestServlet):
                 if new_password:
                     password_hash = await self.auth_handler.hash(new_password)
                     await self.auth_handler.set_session_data(
-                        e.session_id, "password_hash", password_hash
+                        e.session_id,
+                        UIAuthSessionDataConstants.PASSWORD_HASH,
+                        password_hash,
                     )
                 raise
 
@@ -260,7 +261,7 @@ class PasswordRestServlet(RestServlet):
             password_hash = await self.auth_handler.hash(new_password)
         elif session_id is not None:
             password_hash = await self.auth_handler.get_session_data(
-                session_id, "password_hash", None
+                session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
             )
         else:
             # UI validation was skipped, but the request did not include a new
@@ -304,19 +305,18 @@ class DeactivateAccountRestServlet(RestServlet):
         # allow ASes to deactivate their own users
         if requester.app_service:
             await self._deactivate_account_handler.deactivate_account(
-                requester.user.to_string(), erase
+                requester.user.to_string(), erase, requester
             )
             return 200, {}
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "deactivate your account",
+            requester, request, body, "deactivate your account",
         )
         result = await self._deactivate_account_handler.deactivate_account(
-            requester.user.to_string(), erase, id_server=body.get("id_server")
+            requester.user.to_string(),
+            erase,
+            requester,
+            id_server=body.get("id_server"),
         )
         if result:
             id_server_unbind_result = "success"
@@ -695,11 +695,7 @@ class ThreepidAddRestServlet(RestServlet):
         assert_valid_client_secret(client_secret)
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "add a third-party identifier to your account",
+            requester, request, body, "add a third-party identifier to your account",
         )
 
         validation_session = await self.identity_handler.validate_threepid_session(
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index fab077747f..75ece1c911 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import SynapseError
@@ -23,6 +24,9 @@ from synapse.http.servlet import RestServlet, parse_string
 
 from ._base import client_patterns
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -35,28 +39,12 @@ class AuthRestServlet(RestServlet):
 
     PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
         self.registration_handler = hs.get_registration_handler()
-
-        # SSO configuration.
-        self._cas_enabled = hs.config.cas_enabled
-        if self._cas_enabled:
-            self._cas_handler = hs.get_cas_handler()
-            self._cas_server_url = hs.config.cas_server_url
-            self._cas_service_url = hs.config.cas_service_url
-        self._saml_enabled = hs.config.saml2_enabled
-        if self._saml_enabled:
-            self._saml_handler = hs.get_saml_handler()
-        self._oidc_enabled = hs.config.oidc_enabled
-        if self._oidc_enabled:
-            self._oidc_handler = hs.get_oidc_handler()
-            self._cas_server_url = hs.config.cas_server_url
-            self._cas_service_url = hs.config.cas_service_url
-
         self.recaptcha_template = hs.config.recaptcha_template
         self.terms_template = hs.config.terms_template
         self.success_template = hs.config.fallback_success_template
@@ -85,32 +73,7 @@ class AuthRestServlet(RestServlet):
         elif stagetype == LoginType.SSO:
             # Display a confirmation page which prompts the user to
             # re-authenticate with their SSO provider.
-            if self._cas_enabled:
-                # Generate a request to CAS that redirects back to an endpoint
-                # to verify the successful authentication.
-                sso_redirect_url = self._cas_handler.get_redirect_url(
-                    {"session": session},
-                )
-
-            elif self._saml_enabled:
-                # Some SAML identity providers (e.g. Google) require a
-                # RelayState parameter on requests. It is not necessary here, so
-                # pass in a dummy redirect URL (which will never get used).
-                client_redirect_url = b"unused"
-                sso_redirect_url = self._saml_handler.handle_redirect_request(
-                    client_redirect_url, session
-                )
-
-            elif self._oidc_enabled:
-                client_redirect_url = b""
-                sso_redirect_url = await self._oidc_handler.handle_redirect_request(
-                    request, client_redirect_url, session
-                )
-
-            else:
-                raise SynapseError(400, "Homeserver not configured for SSO.")
-
-            html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
+            html = await self.auth_handler.start_sso_ui_auth(request, session)
 
         else:
             raise SynapseError(404, "Unknown auth stage type")
@@ -134,7 +97,7 @@ class AuthRestServlet(RestServlet):
             authdict = {"response": response, "session": session}
 
             success = await self.auth_handler.add_oob_auth(
-                LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request)
+                LoginType.RECAPTCHA, authdict, request.getClientIP()
             )
 
             if success:
@@ -150,7 +113,7 @@ class AuthRestServlet(RestServlet):
             authdict = {"session": session}
 
             success = await self.auth_handler.add_oob_auth(
-                LoginType.TERMS, authdict, self.hs.get_ip_from_request(request)
+                LoginType.TERMS, authdict, request.getClientIP()
             )
 
             if success:
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index af117cb27c..314e01dfe4 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -83,11 +83,7 @@ class DeleteDevicesRestServlet(RestServlet):
         assert_params_in_dict(body, ["devices"])
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "remove device(s) from your account",
+            requester, request, body, "remove device(s) from your account",
         )
 
         await self.device_handler.delete_devices(
@@ -133,11 +129,7 @@ class DeviceRestServlet(RestServlet):
                 raise
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "remove a device from your account",
+            requester, request, body, "remove a device from your account",
         )
 
         await self.device_handler.delete_device(requester.user.to_string(), device_id)
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index b91996c738..a6134ead8a 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -271,11 +271,7 @@ class SigningKeyUploadServlet(RestServlet):
         body = parse_json_object_from_request(request)
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "add a device signing key to your account",
+            requester, request, body, "add a device signing key to your account",
         )
 
         result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 6b5a1b7109..b093183e79 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -38,6 +38,7 @@ from synapse.config.ratelimiting import FederationRateLimitConfig
 from synapse.config.registration import RegistrationConfig
 from synapse.config.server import is_threepid_reserved
 from synapse.handlers.auth import AuthHandler
+from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http.server import finish_request, respond_with_html
 from synapse.http.servlet import (
     RestServlet,
@@ -353,7 +354,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
                 403, "Registration has been disabled", errcode=Codes.FORBIDDEN
             )
 
-        ip = self.hs.get_ip_from_request(request)
+        ip = request.getClientIP()
         with self.ratelimiter.ratelimit(ip) as wait_deferred:
             await wait_deferred
 
@@ -494,11 +495,11 @@ class RegisterRestServlet(RestServlet):
             # user here. We carry on and go through the auth checks though,
             # for paranoia.
             registered_user_id = await self.auth_handler.get_session_data(
-                session_id, "registered_user_id", None
+                session_id, UIAuthSessionDataConstants.REGISTERED_USER_ID, None
             )
             # Extract the previously-hashed password from the session.
             password_hash = await self.auth_handler.get_session_data(
-                session_id, "password_hash", None
+                session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
             )
 
         # Ensure that the username is valid.
@@ -513,11 +514,7 @@ class RegisterRestServlet(RestServlet):
         # not this will raise a user-interactive auth error.
         try:
             auth_result, params, session_id = await self.auth_handler.check_ui_auth(
-                self._registration_flows,
-                request,
-                body,
-                self.hs.get_ip_from_request(request),
-                "register a new account",
+                self._registration_flows, request, body, "register a new account",
             )
         except InteractiveAuthIncompleteError as e:
             # The user needs to provide more steps to complete auth.
@@ -532,7 +529,9 @@ class RegisterRestServlet(RestServlet):
             if not password_hash and password:
                 password_hash = await self.auth_handler.hash(password)
                 await self.auth_handler.set_session_data(
-                    e.session_id, "password_hash", password_hash
+                    e.session_id,
+                    UIAuthSessionDataConstants.PASSWORD_HASH,
+                    password_hash,
                 )
             raise
 
@@ -633,7 +632,9 @@ class RegisterRestServlet(RestServlet):
             # Remember that the user account has been registered (and the user
             # ID it was registered with, since it might not have been specified).
             await self.auth_handler.set_session_data(
-                session_id, "registered_user_id", registered_user_id
+                session_id,
+                UIAuthSessionDataConstants.REGISTERED_USER_ID,
+                registered_user_id,
             )
 
             registered = True
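
Replacing the bare session-data strings with named constants removes a class of typo bugs: a misspelt key would silently read or write the wrong session slot. Judging from the literals they replace above, the constants presumably look like this sketch (the canonical definitions live in synapse/handlers/ui_auth/__init__.py):

    class UIAuthSessionDataConstants:
        # Values inferred from the string literals in this diff, not
        # copied from the real module.
        REGISTERED_USER_ID = "registered_user_id"
        PASSWORD_HASH = "password_hash"
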
diff --git a/synapse/rest/synapse/client/pick_idp.py b/synapse/rest/synapse/client/pick_idp.py
new file mode 100644
index 0000000000..e5b720bbca
--- /dev/null
+++ b/synapse/rest/synapse/client/pick_idp.py
@@ -0,0 +1,82 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING
+
+from synapse.http.server import (
+    DirectServeHtmlResource,
+    finish_request,
+    respond_with_html,
+)
+from synapse.http.servlet import parse_string
+from synapse.http.site import SynapseRequest
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class PickIdpResource(DirectServeHtmlResource):
+    """IdP picker resource.
+
+    This resource gets mounted under /_synapse/client/pick_idp. It serves an HTML page
+    which prompts the user to choose an Identity Provider from the list.
+    """
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self._sso_handler = hs.get_sso_handler()
+        self._sso_login_idp_picker_template = (
+            hs.config.sso.sso_login_idp_picker_template
+        )
+        self._server_name = hs.hostname
+
+    async def _async_render_GET(self, request: SynapseRequest) -> None:
+        client_redirect_url = parse_string(request, "redirectUrl", required=True)
+        idp = parse_string(request, "idp", required=False)
+
+        # if we need to pick an IdP, do so
+        if not idp:
+            return await self._serve_id_picker(request, client_redirect_url)
+
+        # otherwise, redirect to the IdP's redirect URI
+        providers = self._sso_handler.get_identity_providers()
+        auth_provider = providers.get(idp)
+        if not auth_provider:
+            logger.info("Unknown idp %r", idp)
+            self._sso_handler.render_error(
+                request, "unknown_idp", "Unknown identity provider ID"
+            )
+            return
+
+        sso_url = await auth_provider.handle_redirect_request(
+            request, client_redirect_url.encode("utf8")
+        )
+        logger.info("Redirecting to %s", sso_url)
+        request.redirect(sso_url)
+        finish_request(request)
+
+    async def _serve_id_picker(
+        self, request: SynapseRequest, client_redirect_url: str
+    ) -> None:
+        # serve up the IdP picker
+        providers = self._sso_handler.get_identity_providers()
+        html = self._sso_login_idp_picker_template.render(
+            redirect_url=client_redirect_url,
+            server_name=self._server_name,
+            providers=providers.values(),
+        )
+        respond_with_html(request, 200, html)
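
To see how the endpoint and template fit together: the picker renders one link per provider back into itself with the chosen idp. A hedged Jinja2 sketch; the markup and the idp_id/idp_name attributes are assumptions about the provider objects, not the real sso_login_idp_picker.html:

    from jinja2 import Environment

    picker = Environment(autoescape=True).from_string(
        "<ul class='radiobuttons'>"
        "{% for p in providers %}"
        "<li><a href='/_synapse/client/pick_idp"
        "?redirectUrl={{ redirect_url }}&idp={{ p.idp_id }}'>"
        "{{ p.idp_name }}</a></li>"
        "{% endfor %}"
        "</ul>"
    )

    class FakeIdp:
        def __init__(self, idp_id, idp_name):
            self.idp_id, self.idp_name = idp_id, idp_name

    print(picker.render(
        redirect_url="https://client.example/done",
        providers=[FakeIdp("oidc", "Example SSO"), FakeIdp("saml", "Corp SAML")],
    ))
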
diff --git a/synapse/server.py b/synapse/server.py
index a198b0eb46..d4c235cda5 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -283,10 +283,6 @@ class HomeServer(metaclass=abc.ABCMeta):
         """
         return self._reactor
 
-    def get_ip_from_request(self, request) -> str:
-        # X-Forwarded-For is handled by our custom request type.
-        return request.getClientIP()
-
     def is_mine(self, domain_specific_string: DomainSpecificString) -> bool:
         return domain_specific_string.domain == self.hostname
 
@@ -505,7 +501,7 @@ class HomeServer(metaclass=abc.ABCMeta):
         return InitialSyncHandler(self)
 
     @cache_in_self
-    def get_profile_handler(self):
+    def get_profile_handler(self) -> ProfileHandler:
         return ProfileHandler(self)
 
     @cache_in_self
diff --git a/synapse/static/client/login/style.css b/synapse/static/client/login/style.css
index 83e4f6abc8..dd76714a92 100644
--- a/synapse/static/client/login/style.css
+++ b/synapse/static/client/login/style.css
@@ -31,6 +31,11 @@ form {
     margin: 10px 0 0 0;
 }
 
+ul.radiobuttons {
+    text-align: left;
+    list-style: none;
+}
+
 /*
  * Add some padding to the viewport.
  */
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index d1b5760c2c..6cfadc2b4e 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -42,7 +42,6 @@ from synapse.api.errors import StoreError
 from synapse.config.database import DatabaseConnectionConfig
 from synapse.logging.context import (
     LoggingContext,
-    LoggingContextOrSentinel,
     current_context,
     make_deferred_yieldable,
 )
@@ -180,6 +179,9 @@ class LoggingDatabaseConnection:
 _CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]]
 
 
+R = TypeVar("R")
+
+
 class LoggingTransaction:
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging and metrics to the .execute()
@@ -267,6 +269,20 @@ class LoggingTransaction:
             for val in args:
                 self.execute(sql, val)
 
+    def execute_values(self, sql: str, *args: Any) -> List[Tuple]:
+        """Corresponds to psycopg2.extras.execute_values. Only available when
+        using postgres.
+
+        Always sets fetch=True when calling `execute_values`, so it will
+        return the results.
+        """
+        assert isinstance(self.database_engine, PostgresEngine)
+        from psycopg2.extras import execute_values  # type: ignore
+
+        return self._do_execute(
+            lambda *x: execute_values(self.txn, *x, fetch=True), sql, *args
+        )
+
     def execute(self, sql: str, *args: Any) -> None:
         self._do_execute(self.txn.execute, sql, *args)
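
For reference, psycopg2's execute_values expands a single "VALUES %s" placeholder into one multi-row statement, and with fetch=True returns any RETURNING rows, which is what the wrapper above relies on. A minimal standalone sketch (DSN and table are placeholders):

    import psycopg2
    from psycopg2.extras import execute_values

    conn = psycopg2.connect("dbname=example")  # placeholder DSN
    with conn, conn.cursor() as cur:
        # One round trip inserts every row; fetch=True collects the
        # RETURNING values.
        ids = execute_values(
            cur,
            "INSERT INTO widgets (name, size) VALUES %s RETURNING id",
            [("bolt", 3), ("nut", 5)],
            fetch=True,
        )
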
 
@@ -277,7 +293,7 @@ class LoggingTransaction:
         "Strip newlines out of SQL so that the loggers in the DB are on one line"
         return " ".join(line.strip() for line in sql.splitlines() if line.strip())
 
-    def _do_execute(self, func, sql: str, *args: Any) -> None:
+    def _do_execute(self, func: Callable[..., R], sql: str, *args: Any) -> R:
         sql = self._make_sql_one_line(sql)
 
         # TODO(paul): Maybe use 'info' and 'debug' for values?
@@ -348,9 +364,6 @@ class PerformanceCounters:
         return top_n_counters
 
 
-R = TypeVar("R")
-
-
 class DatabasePool:
     """Wraps a single physical database and connection pool.
 
@@ -671,12 +684,15 @@ class DatabasePool:
         Returns:
             The result of func
         """
-        parent_context = current_context()  # type: Optional[LoggingContextOrSentinel]
-        if not parent_context:
+        curr_context = current_context()
+        if not curr_context:
             logger.warning(
                 "Starting db connection from sentinel context: metrics will be lost"
             )
             parent_context = None
+        else:
+            assert isinstance(curr_context, LoggingContext)
+            parent_context = curr_context
 
         start_time = monotonic_time()
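
Typing _do_execute with a TypeVar lets it propagate the wrapped callable's return type instead of discarding it, which is what makes the new execute_values return type checkable. The pattern in isolation:

    from typing import Any, Callable, TypeVar

    R = TypeVar("R")

    def traced(func: Callable[..., R], *args: Any) -> R:
        # A type checker now knows traced(len, "abc") is an int and
        # traced(str.upper, "x") is a str.
        return func(*args)
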
 
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 701748f93b..c4de07a0a8 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -127,9 +127,6 @@ class DataStore(
         self._presence_id_gen = StreamIdGenerator(
             db_conn, "presence_stream", "stream_id"
         )
-        self._device_inbox_id_gen = StreamIdGenerator(
-            db_conn, "device_inbox", "stream_id"
-        )
         self._public_room_id_gen = StreamIdGenerator(
             db_conn, "public_room_list_stream", "stream_id"
         )
@@ -189,36 +186,6 @@ class DataStore(
             prefilled_cache=presence_cache_prefill,
         )
 
-        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
-        device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
-            db_conn,
-            "device_inbox",
-            entity_column="user_id",
-            stream_column="stream_id",
-            max_value=max_device_inbox_id,
-            limit=1000,
-        )
-        self._device_inbox_stream_cache = StreamChangeCache(
-            "DeviceInboxStreamChangeCache",
-            min_device_inbox_id,
-            prefilled_cache=device_inbox_prefill,
-        )
-        # The federation outbox and the local device inbox uses the same
-        # stream_id generator.
-        device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
-            db_conn,
-            "device_federation_outbox",
-            entity_column="destination",
-            stream_column="stream_id",
-            max_value=max_device_inbox_id,
-            limit=1000,
-        )
-        self._device_federation_outbox_stream_cache = StreamChangeCache(
-            "DeviceFederationOutboxStreamChangeCache",
-            min_device_outbox_id,
-            prefilled_cache=device_outbox_prefill,
-        )
-
         device_list_max = self._device_list_id_gen.get_current_token()
         self._device_list_stream_cache = StreamChangeCache(
             "DeviceListStreamChangeCache", device_list_max
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 49ee23470d..bad8260892 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -16,7 +16,7 @@
 
 import abc
 import logging
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Set, Tuple
 
 from synapse.api.constants import AccountDataTypes
 from synapse.storage._base import SQLBaseStore, db_to_json
@@ -24,7 +24,7 @@ from synapse.storage.database import DatabasePool
 from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
-from synapse.util.caches.descriptors import _CacheContext, cached
+from synapse.util.caches.descriptors import cached
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 logger = logging.getLogger(__name__)
@@ -287,35 +287,34 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
             "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
         )
 
-    @cached(num_args=2, cache_context=True, max_entries=5000)
-    async def is_ignored_by(
-        self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext
-    ) -> bool:
-        ignored_account_data = await self.get_global_account_data_by_type_for_user(
-            AccountDataTypes.IGNORED_USER_LIST,
-            ignorer_user_id,
-            on_invalidate=cache_context.invalidate,
-        )
-        if not ignored_account_data:
-            return False
+    @cached(max_entries=5000, iterable=True)
+    async def ignored_by(self, user_id: str) -> Set[str]:
+        """
+        Get users which ignore the given user.
 
-        try:
-            return ignored_user_id in ignored_account_data.get("ignored_users", {})
-        except TypeError:
-            # The type of the ignored_users field is invalid.
-            return False
+        Args:
+            user_id: The user ID which might be ignored.
+
+        Returns:
+            The user IDs which ignore the given user.
+        """
+        return set(
+            await self.db_pool.simple_select_onecol(
+                table="ignored_users",
+                keyvalues={"ignored_user_id": user_id},
+                retcol="ignorer_user_id",
+                desc="ignored_by",
+            )
+        )
 
 
 class AccountDataStore(AccountDataWorkerStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         self._account_data_id_gen = StreamIdGenerator(
             db_conn,
-            "account_data_max_stream_id",
+            "room_account_data",
             "stream_id",
-            extra_tables=[
-                ("room_account_data", "stream_id"),
-                ("room_tags_revisions", "stream_id"),
-            ],
+            extra_tables=[("room_tags_revisions", "stream_id")],
         )
 
         super().__init__(database, db_conn, hs)
@@ -360,14 +359,6 @@ class AccountDataStore(AccountDataWorkerStore):
                 lock=False,
             )
 
-            # it's theoretically possible for the above to succeed and the
-            # below to fail - in which case we might reuse a stream id on
-            # restart, and the above update might not get propagated. That
-            # doesn't sound any worse than the whole update getting lost,
-            # which is what would happen if we combined the two into one
-            # transaction.
-            await self._update_max_stream_id(next_id)
-
             self._account_data_stream_cache.entity_has_changed(user_id, next_id)
             self.get_account_data_for_user.invalidate((user_id,))
             self.get_account_data_for_room.invalidate((user_id, room_id))
@@ -390,32 +381,16 @@ class AccountDataStore(AccountDataWorkerStore):
         Returns:
             The maximum stream ID.
         """
-        content_json = json_encoder.encode(content)
-
         async with self._account_data_id_gen.get_next() as next_id:
-            # no need to lock here as account_data has a unique constraint on
-            # (user_id, account_data_type) so simple_upsert will retry if
-            # there is a conflict.
-            await self.db_pool.simple_upsert(
-                desc="add_user_account_data",
-                table="account_data",
-                keyvalues={"user_id": user_id, "account_data_type": account_data_type},
-                values={"stream_id": next_id, "content": content_json},
-                lock=False,
+            await self.db_pool.runInteraction(
+                "add_user_account_data",
+                self._add_account_data_for_user,
+                next_id,
+                user_id,
+                account_data_type,
+                content,
             )
 
-            # it's theoretically possible for the above to succeed and the
-            # below to fail - in which case we might reuse a stream id on
-            # restart, and the above update might not get propagated. That
-            # doesn't sound any worse than the whole update getting lost,
-            # which is what would happen if we combined the two into one
-            # transaction.
-            #
-            # Note: This is only here for backwards compat to allow admins to
-            # roll back to a previous Synapse version. Next time we update the
-            # database version we can remove this table.
-            await self._update_max_stream_id(next_id)
-
             self._account_data_stream_cache.entity_has_changed(user_id, next_id)
             self.get_account_data_for_user.invalidate((user_id,))
             self.get_global_account_data_by_type_for_user.invalidate(
@@ -424,23 +399,67 @@ class AccountDataStore(AccountDataWorkerStore):
 
         return self._account_data_id_gen.get_current_token()
 
-    async def _update_max_stream_id(self, next_id: int) -> None:
-        """Update the max stream_id
+    def _add_account_data_for_user(
+        self,
+        txn,
+        next_id: int,
+        user_id: str,
+        account_data_type: str,
+        content: JsonDict,
+    ) -> None:
+        content_json = json_encoder.encode(content)
 
-        Args:
-            next_id: The the revision to advance to.
-        """
+        # no need to lock here as account_data has a unique constraint on
+        # (user_id, account_data_type) so simple_upsert will retry if
+        # there is a conflict.
+        self.db_pool.simple_upsert_txn(
+            txn,
+            table="account_data",
+            keyvalues={"user_id": user_id, "account_data_type": account_data_type},
+            values={"stream_id": next_id, "content": content_json},
+            lock=False,
+        )
 
-        # Note: This is only here for backwards compat to allow admins to
-        # roll back to a previous Synapse version. Next time we update the
-        # database version we can remove this table.
+        # Ignored users get denormalized into a separate table as an optimisation.
+        if account_data_type != AccountDataTypes.IGNORED_USER_LIST:
+            return
 
-        def _update(txn):
-            update_max_id_sql = (
-                "UPDATE account_data_max_stream_id"
-                " SET stream_id = ?"
-                " WHERE stream_id < ?"
+        # Insert / delete to sync the list of ignored users.
+        previously_ignored_users = set(
+            self.db_pool.simple_select_onecol_txn(
+                txn,
+                table="ignored_users",
+                keyvalues={"ignorer_user_id": user_id},
+                retcol="ignored_user_id",
             )
-            txn.execute(update_max_id_sql, (next_id, next_id))
+        )
+
+        # If the data is invalid, no one is ignored.
+        ignored_users_content = content.get("ignored_users", {})
+        if isinstance(ignored_users_content, dict):
+            currently_ignored_users = set(ignored_users_content)
+        else:
+            currently_ignored_users = set()
+
+        # Delete entries which are no longer ignored.
+        self.db_pool.simple_delete_many_txn(
+            txn,
+            table="ignored_users",
+            column="ignored_user_id",
+            iterable=previously_ignored_users - currently_ignored_users,
+            keyvalues={"ignorer_user_id": user_id},
+        )
+
+        # Add entries which are newly ignored.
+        self.db_pool.simple_insert_many_txn(
+            txn,
+            table="ignored_users",
+            values=[
+                {"ignorer_user_id": user_id, "ignored_user_id": u}
+                for u in currently_ignored_users - previously_ignored_users
+            ],
+        )
 
-        await self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
+        # Invalidate the cache for any ignored users which were added or removed.
+        for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
+            self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
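
The synchronisation above is plain set algebra; a worked toy example of the three operations it relies on:

    previously = {"@alice:example.org", "@bob:example.org"}
    currently = {"@bob:example.org", "@carol:example.org"}

    assert previously - currently == {"@alice:example.org"}  # rows to delete
    assert currently - previously == {"@carol:example.org"}  # rows to insert
    # Symmetric difference: every user whose ignored_by cache entry is stale.
    assert previously ^ currently == {"@alice:example.org", "@carol:example.org"}
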
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index e96a8b3f43..c53c836337 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -470,43 +470,35 @@ class ClientIpStore(ClientIpWorkerStore):
         for entry in to_update.items():
             (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
 
-            try:
-                self.db_pool.simple_upsert_txn(
+            self.db_pool.simple_upsert_txn(
+                txn,
+                table="user_ips",
+                keyvalues={"user_id": user_id, "access_token": access_token, "ip": ip},
+                values={
+                    "user_agent": user_agent,
+                    "device_id": device_id,
+                    "last_seen": last_seen,
+                },
+                lock=False,
+            )
+
+            # Technically an access token might not be associated with
+            # a device so we need to check.
+            if device_id:
+                # this is always an update rather than an upsert: the row should
+                # already exist, and if it doesn't, that may be because it has been
+                # deleted, and we don't want to re-create it.
+                self.db_pool.simple_update_txn(
                     txn,
-                    table="user_ips",
-                    keyvalues={
-                        "user_id": user_id,
-                        "access_token": access_token,
-                        "ip": ip,
-                    },
-                    values={
+                    table="devices",
+                    keyvalues={"user_id": user_id, "device_id": device_id},
+                    updatevalues={
                         "user_agent": user_agent,
-                        "device_id": device_id,
                         "last_seen": last_seen,
+                        "ip": ip,
                     },
-                    lock=False,
                 )
 
-                # Technically an access token might not be associated with
-                # a device so we need to check.
-                if device_id:
-                    # this is always an update rather than an upsert: the row should
-                    # already exist, and if it doesn't, that may be because it has been
-                    # deleted, and we don't want to re-create it.
-                    self.db_pool.simple_update_txn(
-                        txn,
-                        table="devices",
-                        keyvalues={"user_id": user_id, "device_id": device_id},
-                        updatevalues={
-                            "user_agent": user_agent,
-                            "last_seen": last_seen,
-                            "ip": ip,
-                        },
-                    )
-            except Exception as e:
-                # Failed to upsert, log and continue
-                logger.error("Failed to insert client IP %r: %r", entry, e)
-
     async def get_last_client_ip_by_device(
         self, user_id: str, device_id: Optional[str]
     ) -> Dict[Tuple[str, str], dict]:
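
The devices write above is deliberately an UPDATE rather than an upsert, so a row deleted elsewhere stays deleted. A rough sqlite3 sketch of the distinction (schema is illustrative):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE devices (user_id, device_id, last_seen)")

    cur = conn.execute(
        "UPDATE devices SET last_seen = ? WHERE user_id = ? AND device_id = ?",
        (12345, "@alice:example.org", "PHONE"),
    )
    # rowcount == 0: the row is absent (perhaps deleted), and a plain
    # UPDATE correctly declines to resurrect it, unlike an upsert.
    assert cur.rowcount == 0
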
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index d42faa3f1f..58d3f71e45 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -17,15 +17,100 @@ import logging
 from typing import List, Tuple
 
 from synapse.logging.opentracing import log_kv, set_tag, trace
-from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
+from synapse.replication.tcp.streams import ToDeviceStream
+from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.util import json_encoder
 from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 logger = logging.getLogger(__name__)
 
 
 class DeviceInboxWorkerStore(SQLBaseStore):
+    def __init__(self, database: DatabasePool, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        self._instance_name = hs.get_instance_name()
+
+        # Map of (user_id, device_id) to the last stream_id that has been
+        # deleted up to. This is so that we can no-op deletions.
+        self._last_device_delete_cache = ExpiringCache(
+            cache_name="last_device_delete_cache",
+            clock=self._clock,
+            max_len=10000,
+            expiry_ms=30 * 60 * 1000,
+        )
+
+        if isinstance(database.engine, PostgresEngine):
+            self._can_write_to_device = (
+                self._instance_name in hs.config.worker.writers.to_device
+            )
+
+            self._device_inbox_id_gen = MultiWriterIdGenerator(
+                db_conn=db_conn,
+                db=database,
+                stream_name="to_device",
+                instance_name=self._instance_name,
+                table="device_inbox",
+                instance_column="instance_name",
+                id_column="stream_id",
+                sequence_name="device_inbox_sequence",
+                writers=hs.config.worker.writers.to_device,
+            )
+        else:
+            self._can_write_to_device = True
+            self._device_inbox_id_gen = StreamIdGenerator(
+                db_conn, "device_inbox", "stream_id"
+            )
+
+        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
+        device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
+            db_conn,
+            "device_inbox",
+            entity_column="user_id",
+            stream_column="stream_id",
+            max_value=max_device_inbox_id,
+            limit=1000,
+        )
+        self._device_inbox_stream_cache = StreamChangeCache(
+            "DeviceInboxStreamChangeCache",
+            min_device_inbox_id,
+            prefilled_cache=device_inbox_prefill,
+        )
+
+        # The federation outbox and the local device inbox use the same
+        # stream_id generator.
+        device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
+            db_conn,
+            "device_federation_outbox",
+            entity_column="destination",
+            stream_column="stream_id",
+            max_value=max_device_inbox_id,
+            limit=1000,
+        )
+        self._device_federation_outbox_stream_cache = StreamChangeCache(
+            "DeviceFederationOutboxStreamChangeCache",
+            min_device_outbox_id,
+            prefilled_cache=device_outbox_prefill,
+        )
+
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
+        if stream_name == ToDeviceStream.NAME:
+            self._device_inbox_id_gen.advance(instance_name, token)
+            for row in rows:
+                if row.entity.startswith("@"):
+                    self._device_inbox_stream_cache.entity_has_changed(
+                        row.entity, token
+                    )
+                else:
+                    self._device_federation_outbox_stream_cache.entity_has_changed(
+                        row.entity, token
+                    )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
+
     def get_to_device_stream_token(self):
         return self._device_inbox_id_gen.get_current_token()
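
process_replication_rows tells the two stream-change caches apart purely by entity shape: Matrix user IDs always begin with "@", while federation outbox rows are keyed by bare server names. Reduced to a toy dispatch:

    def cache_for_entity(entity: str) -> str:
        # "@alice:example.org" -> per-user device inbox cache;
        # "matrix.example.org" -> per-destination federation outbox cache.
        return "device_inbox" if entity.startswith("@") else "federation_outbox"

    assert cache_for_entity("@alice:example.org") == "device_inbox"
    assert cache_for_entity("matrix.example.org") == "federation_outbox"
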
 
@@ -278,52 +363,6 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             "get_all_new_device_messages", get_all_new_device_messages_txn
         )
 
-
-class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
-    DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
-
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        super().__init__(database, db_conn, hs)
-
-        self.db_pool.updates.register_background_index_update(
-            "device_inbox_stream_index",
-            index_name="device_inbox_stream_id_user_id",
-            table="device_inbox",
-            columns=["stream_id", "user_id"],
-        )
-
-        self.db_pool.updates.register_background_update_handler(
-            self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
-        )
-
-    async def _background_drop_index_device_inbox(self, progress, batch_size):
-        def reindex_txn(conn):
-            txn = conn.cursor()
-            txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
-            txn.close()
-
-        await self.db_pool.runWithConnection(reindex_txn)
-
-        await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
-
-        return 1
-
-
-class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
-    DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
-
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        super().__init__(database, db_conn, hs)
-
-        # Map of (user_id, device_id) to the last stream_id that has been
-        # deleted up to. This is so that we can no op deletions.
-        self._last_device_delete_cache = ExpiringCache(
-            cache_name="last_device_delete_cache",
-            clock=self._clock,
-            max_len=10000,
-            expiry_ms=30 * 60 * 1000,
-        )
-
     @trace
     async def add_messages_to_device_inbox(
         self,
@@ -342,6 +381,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
             The new stream_id.
         """
 
+        assert self._can_write_to_device
+
         def add_messages_txn(txn, now_ms, stream_id):
             # Add the local messages directly to the local inbox.
             self._add_messages_to_local_device_inbox_txn(
@@ -351,16 +392,20 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
             # Add the remote messages to the federation outbox.
             # We'll send them to a remote server when we next send a
             # federation transaction to that destination.
-            sql = (
-                "INSERT INTO device_federation_outbox"
-                " (destination, stream_id, queued_ts, messages_json)"
-                " VALUES (?,?,?,?)"
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="device_federation_outbox",
+                values=[
+                    {
+                        "destination": destination,
+                        "stream_id": stream_id,
+                        "queued_ts": now_ms,
+                        "messages_json": json_encoder.encode(edu),
+                        "instance_name": self._instance_name,
+                    }
+                    for destination, edu in remote_messages_by_destination.items()
+                ],
             )
-            rows = []
-            for destination, edu in remote_messages_by_destination.items():
-                edu_json = json_encoder.encode(edu)
-                rows.append((destination, stream_id, now_ms, edu_json))
-            txn.executemany(sql, rows)
 
         async with self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
@@ -379,6 +424,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
     async def add_messages_from_remote_to_device_inbox(
         self, origin: str, message_id: str, local_messages_by_user_then_device: dict
     ) -> int:
+        assert self._can_write_to_device
+
         def add_messages_txn(txn, now_ms, stream_id):
             # Check if we've already inserted a matching message_id for that
             # origin. This can happen if the origin doesn't receive our
@@ -427,38 +474,45 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
     def _add_messages_to_local_device_inbox_txn(
         self, txn, stream_id, messages_by_user_then_device
     ):
+        assert self._can_write_to_device
+
         local_by_user_then_device = {}
         for user_id, messages_by_device in messages_by_user_then_device.items():
             messages_json_for_user = {}
             devices = list(messages_by_device.keys())
             if len(devices) == 1 and devices[0] == "*":
                 # Handle wildcard device_ids.
-                sql = "SELECT device_id FROM devices WHERE user_id = ?"
-                txn.execute(sql, (user_id,))
+                devices = self.db_pool.simple_select_onecol_txn(
+                    txn,
+                    table="devices",
+                    keyvalues={"user_id": user_id},
+                    retcol="device_id",
+                )
+
                 message_json = json_encoder.encode(messages_by_device["*"])
-                for row in txn:
+                for device_id in devices:
                     # Add the message for all devices for this user on this
                     # server.
-                    device = row[0]
-                    messages_json_for_user[device] = message_json
+                    messages_json_for_user[device_id] = message_json
             else:
                 if not devices:
                     continue
 
-                clause, args = make_in_list_sql_clause(
-                    txn.database_engine, "device_id", devices
+                rows = self.db_pool.simple_select_many_txn(
+                    txn,
+                    table="devices",
+                    keyvalues={"user_id": user_id},
+                    column="device_id",
+                    iterable=devices,
+                    retcols=("device_id",),
                 )
-                sql = "SELECT device_id FROM devices WHERE user_id = ? AND " + clause
 
-                # TODO: Maybe this needs to be done in batches if there are
-                # too many local devices for a given user.
-                txn.execute(sql, [user_id] + list(args))
-                for row in txn:
+                for row in rows:
                     # Only insert into the local inbox if the device exists on
                     # this server
-                    device = row[0]
-                    message_json = json_encoder.encode(messages_by_device[device])
-                    messages_json_for_user[device] = message_json
+                    device_id = row["device_id"]
+                    message_json = json_encoder.encode(messages_by_device[device_id])
+                    messages_json_for_user[device_id] = message_json
 
             if messages_json_for_user:
                 local_by_user_then_device[user_id] = messages_json_for_user
@@ -466,14 +520,52 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
         if not local_by_user_then_device:
             return
 
-        sql = (
-            "INSERT INTO device_inbox"
-            " (user_id, device_id, stream_id, message_json)"
-            " VALUES (?,?,?,?)"
+        self.db_pool.simple_insert_many_txn(
+            txn,
+            table="device_inbox",
+            values=[
+                {
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "stream_id": stream_id,
+                    "message_json": message_json,
+                    "instance_name": self._instance_name,
+                }
+                for user_id, messages_by_device in local_by_user_then_device.items()
+                for device_id, message_json in messages_by_device.items()
+            ],
         )
-        rows = []
-        for user_id, messages_by_device in local_by_user_then_device.items():
-            for device_id, message_json in messages_by_device.items():
-                rows.append((user_id, device_id, stream_id, message_json))
 
-        txn.executemany(sql, rows)
+
+class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
+    DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
+
+    def __init__(self, database: DatabasePool, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        self.db_pool.updates.register_background_index_update(
+            "device_inbox_stream_index",
+            index_name="device_inbox_stream_id_user_id",
+            table="device_inbox",
+            columns=["stream_id", "user_id"],
+        )
+
+        self.db_pool.updates.register_background_update_handler(
+            self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
+        )
+
+    async def _background_drop_index_device_inbox(self, progress, batch_size):
+        def reindex_txn(conn):
+            txn = conn.cursor()
+            txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
+            txn.close()
+
+        await self.db_pool.runWithConnection(reindex_txn)
+
+        await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
+
+        return 1
+
+
+class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
+    pass
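
simple_insert_many_txn, used repeatedly above in place of hand-built INSERT strings, amounts to one parameterised statement handed to executemany. A hedged standalone equivalent, assuming every dict shares the same keys and a qmark-paramstyle driver such as sqlite3:

    def insert_many(cur, table, values):
        if not values:
            return
        keys = sorted(values[0])
        sql = "INSERT INTO %s (%s) VALUES (%s)" % (
            table,
            ", ".join(keys),
            ", ".join("?" for _ in keys),
        )
        cur.executemany(sql, [tuple(row[k] for k in keys) for row in values])
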
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 4d1b92d1aa..1b6ccd51c8 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -707,50 +707,6 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
         """Get the current stream id from the _device_list_id_gen"""
         ...
 
-
-class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
-    async def set_e2e_device_keys(
-        self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
-    ) -> bool:
-        """Stores device keys for a device. Returns whether there was a change
-        or the keys were already in the database.
-        """
-
-        def _set_e2e_device_keys_txn(txn):
-            set_tag("user_id", user_id)
-            set_tag("device_id", device_id)
-            set_tag("time_now", time_now)
-            set_tag("device_keys", device_keys)
-
-            old_key_json = self.db_pool.simple_select_one_onecol_txn(
-                txn,
-                table="e2e_device_keys_json",
-                keyvalues={"user_id": user_id, "device_id": device_id},
-                retcol="key_json",
-                allow_none=True,
-            )
-
-            # In py3 we need old_key_json to match new_key_json type. The DB
-            # returns unicode while encode_canonical_json returns bytes.
-            new_key_json = encode_canonical_json(device_keys).decode("utf-8")
-
-            if old_key_json == new_key_json:
-                log_kv({"Message": "Device key already stored."})
-                return False
-
-            self.db_pool.simple_upsert_txn(
-                txn,
-                table="e2e_device_keys_json",
-                keyvalues={"user_id": user_id, "device_id": device_id},
-                values={"ts_added_ms": time_now, "key_json": new_key_json},
-            )
-            log_kv({"message": "Device keys stored."})
-            return True
-
-        return await self.db_pool.runInteraction(
-            "set_e2e_device_keys", _set_e2e_device_keys_txn
-        )
-
     async def claim_e2e_one_time_keys(
         self, query_list: Iterable[Tuple[str, str, str]]
     ) -> Dict[str, Dict[str, Dict[str, bytes]]]:
@@ -840,6 +796,50 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
         )
 
+
+class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
+    async def set_e2e_device_keys(
+        self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
+    ) -> bool:
+        """Stores device keys for a device. Returns whether there was a change
+        or the keys were already in the database.
+        """
+
+        def _set_e2e_device_keys_txn(txn):
+            set_tag("user_id", user_id)
+            set_tag("device_id", device_id)
+            set_tag("time_now", time_now)
+            set_tag("device_keys", device_keys)
+
+            old_key_json = self.db_pool.simple_select_one_onecol_txn(
+                txn,
+                table="e2e_device_keys_json",
+                keyvalues={"user_id": user_id, "device_id": device_id},
+                retcol="key_json",
+                allow_none=True,
+            )
+
+            # In py3 we need old_key_json to match new_key_json type. The DB
+            # returns unicode while encode_canonical_json returns bytes.
+            new_key_json = encode_canonical_json(device_keys).decode("utf-8")
+
+            if old_key_json == new_key_json:
+                log_kv({"Message": "Device key already stored."})
+                return False
+
+            self.db_pool.simple_upsert_txn(
+                txn,
+                table="e2e_device_keys_json",
+                keyvalues={"user_id": user_id, "device_id": device_id},
+                values={"ts_added_ms": time_now, "key_json": new_key_json},
+            )
+            log_kv({"message": "Device keys stored."})
+            return True
+
+        return await self.db_pool.runInteraction(
+            "set_e2e_device_keys", _set_e2e_device_keys_txn
+        )
+
     async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
         def delete_e2e_keys_by_device_txn(txn):
             log_kv(
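
The early return above hinges on canonical JSON being deterministic, so textual equality implies semantic equality. A quick sketch with the canonicaljson library:

    from canonicaljson import encode_canonical_json

    old = {"b": 2, "a": 1}
    new = {"a": 1, "b": 2}
    # Keys are sorted and whitespace stripped, so both encode identically.
    assert encode_canonical_json(old) == encode_canonical_json(new) == b'{"a":1,"b":2}'
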
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index ebffd89251..8326640d20 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -24,6 +24,8 @@ from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.signatures import SignatureWorkerStore
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.types import Cursor
 from synapse.types import Collection
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.lrucache import LruCache
@@ -32,6 +34,11 @@ from synapse.util.iterutils import batch_iter
 logger = logging.getLogger(__name__)
 
 
+class _NoChainCoverIndex(Exception):
+    def __init__(self, room_id: str):
+        super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,))
+
+
 class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
@@ -151,15 +158,193 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             The set of the difference in auth chains.
         """
 
+        # Check if we have indexed the room so we can use the chain cover
+        # algorithm.
+        room = await self.get_room(room_id)
+        if room["has_auth_chain_index"]:
+            try:
+                return await self.db_pool.runInteraction(
+                    "get_auth_chain_difference_chains",
+                    self._get_auth_chain_difference_using_cover_index_txn,
+                    room_id,
+                    state_sets,
+                )
+            except _NoChainCoverIndex:
+                # For whatever reason we don't actually have a chain cover index
+                # for the events in question, so we fall back to the old method.
+                pass
+
         return await self.db_pool.runInteraction(
             "get_auth_chain_difference",
             self._get_auth_chain_difference_txn,
             state_sets,
         )
 
+    def _get_auth_chain_difference_using_cover_index_txn(
+        self, txn: Cursor, room_id: str, state_sets: List[Set[str]]
+    ) -> Set[str]:
+        """Calculates the auth chain difference using the chain index.
+
+        See docs/auth_chain_difference_algorithm.md for details
+        """
+
+        # First we look up the chain ID/sequence numbers for all the events, and
+        # work out the chain/sequence numbers reachable from each state set.
+
+        initial_events = set(state_sets[0]).union(*state_sets[1:])
+
+        # Map from event_id -> (chain ID, seq no)
+        chain_info = {}  # type: Dict[str, Tuple[int, int]]
+
+        # Map from chain ID -> seq no -> event Id
+        chain_to_event = {}  # type: Dict[int, Dict[int, str]]
+
+        # All the chains that we've found that are reachable from the state
+        # sets.
+        seen_chains = set()  # type: Set[int]
+
+        sql = """
+            SELECT event_id, chain_id, sequence_number
+            FROM event_auth_chains
+            WHERE %s
+        """
+        for batch in batch_iter(initial_events, 1000):
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "event_id", batch
+            )
+            txn.execute(sql % (clause,), args)
+
+            for event_id, chain_id, sequence_number in txn:
+                chain_info[event_id] = (chain_id, sequence_number)
+                seen_chains.add(chain_id)
+                chain_to_event.setdefault(chain_id, {})[sequence_number] = event_id
+
+        # Check that we actually have a chain ID for all the events.
+        events_missing_chain_info = initial_events.difference(chain_info)
+        if events_missing_chain_info:
+            # This can happen due to e.g. downgrade/upgrade of the server. We
+            # raise an exception and fall back to the previous algorithm.
+            logger.info(
+                "Unexpectedly found that events don't have chain IDs in room %s: %s",
+                room_id,
+                events_missing_chain_info,
+            )
+            raise _NoChainCoverIndex(room_id)
+
+        # Corresponds to `state_sets`, except as a map from chain ID to max
+        # sequence number reachable from the state set.
+        set_to_chain = []  # type: List[Dict[int, int]]
+        for state_set in state_sets:
+            chains = {}  # type: Dict[int, int]
+            set_to_chain.append(chains)
+
+            for event_id in state_set:
+                chain_id, seq_no = chain_info[event_id]
+
+                chains[chain_id] = max(seq_no, chains.get(chain_id, 0))
+
+        # Now we look up all links for the chains we have, adding chains to
+        # set_to_chain that are reachable from each set.
+        sql = """
+            SELECT
+                origin_chain_id, origin_sequence_number,
+                target_chain_id, target_sequence_number
+            FROM event_auth_chain_links
+            WHERE %s
+        """
+
+        # (We need to take a copy of `seen_chains` as we want to mutate it in
+        # the loop)
+        for batch in batch_iter(set(seen_chains), 1000):
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "origin_chain_id", batch
+            )
+            txn.execute(sql % (clause,), args)
+
+            for (
+                origin_chain_id,
+                origin_sequence_number,
+                target_chain_id,
+                target_sequence_number,
+            ) in txn:
+                for chains in set_to_chain:
+                    # chains are only reachable if the origin sequence number of
+                    # the link is no greater than the max sequence number in
+                    # the origin chain.
+                    if origin_sequence_number <= chains.get(origin_chain_id, 0):
+                        chains[target_chain_id] = max(
+                            target_sequence_number, chains.get(target_chain_id, 0),
+                        )
+
+                seen_chains.add(target_chain_id)
+
+        # Now for each chain we figure out the maximum sequence number reachable
+        # from *any* state set and the minimum sequence number reachable from
+        # *all* state sets. Events in that range are in the auth chain
+        # difference.
+        result = set()
+
+        # Mapping from chain ID to the range of sequence numbers that should be
+        # pulled from the database.
+        chain_to_gap = {}  # type: Dict[int, Tuple[int, int]]
+
+        for chain_id in seen_chains:
+            min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain)
+            max_seq_no = max(chains.get(chain_id, 0) for chains in set_to_chain)
+
+            if min_seq_no < max_seq_no:
+                # We have a non-empty gap; try to fill it from the events we
+                # already have, otherwise add it to the list of gaps to pull
+                # out of the DB.
+                for seq_no in range(min_seq_no + 1, max_seq_no + 1):
+                    event_id = chain_to_event.get(chain_id, {}).get(seq_no)
+                    if event_id:
+                        result.add(event_id)
+                    else:
+                        chain_to_gap[chain_id] = (min_seq_no, max_seq_no)
+                        break
+
+        if not chain_to_gap:
+            # If there are no gaps to fetch, we're done!
+            return result
+
+        if isinstance(self.database_engine, PostgresEngine):
+            # We can use `execute_values` to efficiently fetch the gaps when
+            # using postgres.
+            sql = """
+                SELECT event_id
+                FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq)
+                WHERE
+                    c.chain_id = l.chain_id
+                    AND min_seq < sequence_number AND sequence_number <= max_seq
+            """
+
+            args = [
+                (chain_id, min_no, max_no)
+                for chain_id, (min_no, max_no) in chain_to_gap.items()
+            ]
+
+            rows = txn.execute_values(sql, args)
+            result.update(r for r, in rows)
+        else:
+            # For SQLite we just fall back to doing a noddy for loop.
+            sql = """
+                SELECT event_id FROM event_auth_chains
+                WHERE chain_id = ? AND ? < sequence_number AND sequence_number <= ?
+            """
+            for chain_id, (min_no, max_no) in chain_to_gap.items():
+                txn.execute(sql, (chain_id, min_no, max_no))
+                result.update(r for r, in txn)
+
+        return result
+
     def _get_auth_chain_difference_txn(
         self, txn, state_sets: List[Set[str]]
     ) -> Set[str]:
+        """Calculates the auth chain difference using a breadth first search.
+
+        This is used when we don't have a cover index for the room.
+        """
 
         # Algorithm Description
         # ~~~~~~~~~~~~~~~~~~~~~
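
The heart of the cover-index path is the per-chain min/max computation: the difference is exactly the events whose sequence number lies strictly above the smallest maximum reachable from every state set and at or below the largest. A worked toy example of the gap computation:

    # Two state sets, as maps of chain ID -> highest reachable sequence number.
    set_to_chain = [{1: 4, 2: 1}, {1: 2, 2: 1}]
    seen_chains = {1, 2}

    gaps = {}
    for chain_id in seen_chains:
        min_seq = min(c.get(chain_id, 0) for c in set_to_chain)
        max_seq = max(c.get(chain_id, 0) for c in set_to_chain)
        if min_seq < max_seq:
            gaps[chain_id] = (min_seq, max_seq)

    # Chain 1: events with sequence numbers 3..4 are in the difference;
    # chain 2 is reachable identically from both sets and contributes nothing.
    assert gaps == {1: (2, 4)}
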
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 90fb1a1f00..186f064036 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -17,7 +17,17 @@
 import itertools
 import logging
 from collections import OrderedDict, namedtuple
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+)
 
 import attr
 from prometheus_client import Counter
@@ -33,9 +43,10 @@ from synapse.storage._base import db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import StateMap, get_domain_from_id
 from synapse.util import json_encoder
-from synapse.util.iterutils import batch_iter
+from synapse.util.iterutils import batch_iter, sorted_topologically
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -89,6 +100,14 @@ class PersistEventsStore:
         self._clock = hs.get_clock()
         self._instance_name = hs.get_instance_name()
 
+        def get_chain_id_txn(txn):
+            txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
+            return txn.fetchone()[0]
+
+        self._event_chain_id_gen = build_sequence_generator(
+            db.engine, get_chain_id_txn, "event_auth_chain_id"
+        )
+
         self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
         self.is_mine_id = hs.is_mine_id
 
@@ -366,6 +385,36 @@ class PersistEventsStore:
         # Insert into event_to_state_groups.
         self._store_event_state_mappings_txn(txn, events_and_contexts)
 
+        self._persist_event_auth_chain_txn(txn, [e for e, _ in events_and_contexts])
+
+        # _store_rejected_events_txn filters out any events which were
+        # rejected, and returns the filtered list.
+        events_and_contexts = self._store_rejected_events_txn(
+            txn, events_and_contexts=events_and_contexts
+        )
+
+        # From this point onwards the events are only ones that weren't
+        # rejected.
+
+        self._update_metadata_tables_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
+            all_events_and_contexts=all_events_and_contexts,
+            backfilled=backfilled,
+        )
+
+        # We call this last as it assumes we've inserted the events into
+        # room_memberships, where applicable.
+        self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
+
+    def _persist_event_auth_chain_txn(
+        self, txn: LoggingTransaction, events: List[EventBase],
+    ) -> None:
+
+        # We only care about state events, so return early if there are no
+        # state events.
+        if not any(e.is_state() for e in events):
+            return
+
         # We want to store event_auth mappings for rejected events, as they're
         # used in state res v2.
         # This is only necessary if the rejected event appears in an accepted
@@ -381,31 +430,357 @@ class PersistEventsStore:
                     "room_id": event.room_id,
                     "auth_id": auth_id,
                 }
-                for event, _ in events_and_contexts
+                for event in events
                 for auth_id in event.auth_event_ids()
                 if event.is_state()
             ],
         )
 
-        # _store_rejected_events_txn filters out any events which were
-        # rejected, and returns the filtered list.
-        events_and_contexts = self._store_rejected_events_txn(
-            txn, events_and_contexts=events_and_contexts
+        # We now calculate chain ID/sequence numbers for any state events we're
+        # persisting. We ignore out of band memberships as we're not in the room
+        # and won't have their auth chain (we'll fix it up later if we join the
+        # room).
+        #
+        # See: docs/auth_chain_difference_algorithm.md
+
+        # We ignore legacy rooms that we aren't filling the chain cover index
+        # for.
+        rows = self.db_pool.simple_select_many_txn(
+            txn,
+            table="rooms",
+            column="room_id",
+            iterable={event.room_id for event in events if event.is_state()},
+            keyvalues={},
+            retcols=("room_id", "has_auth_chain_index"),
         )
+        rooms_using_chain_index = {
+            row["room_id"] for row in rows if row["has_auth_chain_index"]
+        }
 
-        # From this point onwards the events are only ones that weren't
-        # rejected.
+        state_events = {
+            event.event_id: event
+            for event in events
+            if event.is_state() and event.room_id in rooms_using_chain_index
+        }
 
-        self._update_metadata_tables_txn(
+        if not state_events:
+            return
+
+        # Map from event ID to chain ID/sequence number.
+        chain_map = {}  # type: Dict[str, Tuple[int, int]]
+
+        # We need to know the type/state_key and auth events of the events we're
+        # calculating chain IDs for. We don't rely on having the full Event
+        # instances as we'll potentially be pulling more events from the DB and
+        # we don't need the overhead of fetching/parsing the full event JSON.
+        event_to_types = {
+            e.event_id: (e.type, e.state_key) for e in state_events.values()
+        }
+        event_to_auth_chain = {
+            e.event_id: e.auth_event_ids() for e in state_events.values()
+        }
+
+        # Set of event IDs to calculate chain ID/seq numbers for.
+        events_to_calc_chain_id_for = set(state_events)
+
+        # We check if there are any events that need to be handled in the rooms
+        # we're looking at. These should just be out of band memberships, where
+        # we didn't have the auth chain when we first persisted.
+        rows = self.db_pool.simple_select_many_txn(
             txn,
-            events_and_contexts=events_and_contexts,
-            all_events_and_contexts=all_events_and_contexts,
-            backfilled=backfilled,
+            table="event_auth_chain_to_calculate",
+            keyvalues={},
+            column="room_id",
+            iterable={e.room_id for e in state_events.values()},
+            retcols=("event_id", "type", "state_key"),
         )
+        for row in rows:
+            event_id = row["event_id"]
+            event_type = row["type"]
+            state_key = row["state_key"]
+
+            # (We could pull out the auth events for all rows at once using
+            # simple_select_many, but this case happens rarely and almost always
+            # with a single row.)
+            auth_events = self.db_pool.simple_select_onecol_txn(
+                txn, "event_auth", keyvalues={"event_id": event_id}, retcol="auth_id",
+            )
 
-        # We call this last as it assumes we've inserted the events into
-        # room_memberships, where applicable.
-        self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
+            events_to_calc_chain_id_for.add(event_id)
+            event_to_types[event_id] = (event_type, state_key)
+            event_to_auth_chain[event_id] = auth_events
+
+        # First we get the chain ID and sequence numbers for the events'
+        # auth events (that aren't also currently being persisted).
+        #
+        # Note that there is an edge case here where we might not have
+        # calculated chains and sequence numbers for events that were "out
+        # of band". We handle this case by fetching the necessary info and
+        # adding it to the set of events to calculate chain IDs for.
+
+        missing_auth_chains = {
+            a_id
+            for auth_events in event_to_auth_chain.values()
+            for a_id in auth_events
+            if a_id not in events_to_calc_chain_id_for
+        }
+
+        # We loop here in case we find an out of band membership and need to
+        # fetch their auth event info.
+        while missing_auth_chains:
+            sql = """
+                SELECT event_id, events.type, state_key, chain_id, sequence_number
+                FROM events
+                INNER JOIN state_events USING (event_id)
+                LEFT JOIN event_auth_chains USING (event_id)
+                WHERE
+            """
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "event_id", missing_auth_chains,
+            )
+            txn.execute(sql + clause, args)
+
+            missing_auth_chains.clear()
+
+            for auth_id, event_type, state_key, chain_id, sequence_number in txn:
+                event_to_types[auth_id] = (event_type, state_key)
+
+                if chain_id is None:
+                    # No chain ID, so the event was persisted out of band.
+                    # We add it to the list of events to calculate auth chains for.
+
+                    events_to_calc_chain_id_for.add(auth_id)
+
+                    event_to_auth_chain[
+                        auth_id
+                    ] = self.db_pool.simple_select_onecol_txn(
+                        txn,
+                        "event_auth",
+                        keyvalues={"event_id": auth_id},
+                        retcol="auth_id",
+                    )
+
+                    missing_auth_chains.update(
+                        e
+                        for e in event_to_auth_chain[auth_id]
+                        if e not in event_to_types
+                    )
+                else:
+                    chain_map[auth_id] = (chain_id, sequence_number)
+
+        # Now we check if there are any events whose auth chain we don't have;
+        # these should only be out of band memberships.
+        for event_id in sorted_topologically(event_to_auth_chain, event_to_auth_chain):
+            for auth_id in event_to_auth_chain[event_id]:
+                if (
+                    auth_id not in chain_map
+                    and auth_id not in events_to_calc_chain_id_for
+                ):
+                    events_to_calc_chain_id_for.discard(event_id)
+
+                    # If this is an event we're trying to persist we add it to
+                    # the table of events whose chain IDs still need to be
+                    # calculated, so it gets picked up next time around.
+                    # (Otherwise it will already be in the table.)
+                    event = state_events.get(event_id)
+                    if event:
+                        self.db_pool.simple_insert_txn(
+                            txn,
+                            table="event_auth_chain_to_calculate",
+                            values={
+                                "event_id": event.event_id,
+                                "room_id": event.room_id,
+                                "type": event.type,
+                                "state_key": event.state_key,
+                            },
+                        )
+
+                    # We stop checking the event's auth events since we've
+                    # discarded it.
+                    break
+
+        if not events_to_calc_chain_id_for:
+            return
+
+        # We now calculate the chain IDs/sequence numbers for the events. We
+        # do this by looking at the chain ID and sequence number of any auth
+        # event with the same type/state_key and incrementing the sequence
+        # number by one. If there was no match or the chain ID/sequence
+        # number is already taken we generate a new chain.
+        #
+        # We need to do this in a topologically sorted order as we want to
+        # generate chain IDs/sequence numbers of an event's auth events
+        # before the event itself.
+        chains_tuples_allocated = set()  # type: Set[Tuple[int, int]]
+        new_chain_tuples = {}  # type: Dict[str, Tuple[int, int]]
+        for event_id in sorted_topologically(
+            events_to_calc_chain_id_for, event_to_auth_chain
+        ):
+            existing_chain_id = None
+            for auth_id in event_to_auth_chain[event_id]:
+                if event_to_types.get(event_id) == event_to_types.get(auth_id):
+                    existing_chain_id = chain_map[auth_id]
+                    break
+
+            new_chain_tuple = None
+            if existing_chain_id:
+                # We found a chain ID/sequence number candidate, check it's
+                # not already taken.
+                proposed_new_id = existing_chain_id[0]
+                proposed_new_seq = existing_chain_id[1] + 1
+                if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated:
+                    already_allocated = self.db_pool.simple_select_one_onecol_txn(
+                        txn,
+                        table="event_auth_chains",
+                        keyvalues={
+                            "chain_id": proposed_new_id,
+                            "sequence_number": proposed_new_seq,
+                        },
+                        retcol="event_id",
+                        allow_none=True,
+                    )
+                    if already_allocated:
+                        # Mark it as already allocated so we don't need to hit
+                        # the DB again.
+                        chains_tuples_allocated.add((proposed_new_id, proposed_new_seq))
+                    else:
+                        new_chain_tuple = (
+                            proposed_new_id,
+                            proposed_new_seq,
+                        )
+
+            if not new_chain_tuple:
+                new_chain_tuple = (self._event_chain_id_gen.get_next_id_txn(txn), 1)
+
+            chains_tuples_allocated.add(new_chain_tuple)
+
+            chain_map[event_id] = new_chain_tuple
+            new_chain_tuples[event_id] = new_chain_tuple
+
+        self.db_pool.simple_insert_many_txn(
+            txn,
+            table="event_auth_chains",
+            values=[
+                {"event_id": event_id, "chain_id": c_id, "sequence_number": seq}
+                for event_id, (c_id, seq) in new_chain_tuples.items()
+            ],
+        )
+
+        self.db_pool.simple_delete_many_txn(
+            txn,
+            table="event_auth_chain_to_calculate",
+            keyvalues={},
+            column="event_id",
+            iterable=new_chain_tuples,
+        )
+
+        # Now we need to calculate any new links between chains caused by
+        # the new events.
+        #
+        # Links are pairs of chain ID/sequence numbers such that for any
+        # event A (CA, SA) and any event B (CB, SB), B is in A's auth chain
+        # if and only if there is at least one link (CA, S1) -> (CB, S2)
+        # where SA >= S1 and S2 >= SB.
+        #
+        # We try and avoid adding redundant links to the table, e.g. if two
+        # links between the same pair of chains cross (or share an endpoint)
+        # then one of them can be safely dropped.
+        #
+        # To calculate new links we look at every new event and:
+        #   1. Fetch the chain ID/sequence numbers of its auth events,
+        #      discarding any that are reachable by other auth events, or
+        #      that have the same chain ID as the event.
+        #   2. For each retained auth event we:
+        #       a. Add a link from the event's chain ID/sequence number to
+        #          the auth event's; and
+        #       b. Add a link from the event to every chain reachable by the
+        #          auth event.
+
+        # Step 1, fetch all existing links from all the chains we've seen
+        # referenced.
+        chain_links = _LinkMap()
+        rows = self.db_pool.simple_select_many_txn(
+            txn,
+            table="event_auth_chain_links",
+            column="origin_chain_id",
+            iterable={chain_id for chain_id, _ in chain_map.values()},
+            keyvalues={},
+            retcols=(
+                "origin_chain_id",
+                "origin_sequence_number",
+                "target_chain_id",
+                "target_sequence_number",
+            ),
+        )
+        for row in rows:
+            chain_links.add_link(
+                (row["origin_chain_id"], row["origin_sequence_number"]),
+                (row["target_chain_id"], row["target_sequence_number"]),
+                new=False,
+            )
+
+        # We do this in topological order to avoid adding redundant links.
+        for event_id in sorted_topologically(
+            events_to_calc_chain_id_for, event_to_auth_chain
+        ):
+            chain_id, sequence_number = chain_map[event_id]
+
+            # Filter out auth events that are reachable by other auth
+            # events. We do this by looking at every permutation of pairs of
+            # auth events (A, B) to check if B is reachable from A.
+            reduction = {
+                a_id
+                for a_id in event_to_auth_chain[event_id]
+                if chain_map[a_id][0] != chain_id
+            }
+            for start_auth_id, end_auth_id in itertools.permutations(
+                event_to_auth_chain[event_id], r=2,
+            ):
+                if chain_links.exists_path_from(
+                    chain_map[start_auth_id], chain_map[end_auth_id]
+                ):
+                    reduction.discard(end_auth_id)
+
+            # Step 2, figure out what the new links are from the reduced
+            # list of auth events.
+            for auth_id in reduction:
+                auth_chain_id, auth_sequence_number = chain_map[auth_id]
+
+                # Step 2a, add link between the event and auth event
+                chain_links.add_link(
+                    (chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
+                )
+
+                # Step 2b, add a link to chains reachable from the auth
+                # event.
+                for target_id, target_seq in chain_links.get_links_from(
+                    (auth_chain_id, auth_sequence_number)
+                ):
+                    if target_id == chain_id:
+                        continue
+
+                    chain_links.add_link(
+                        (chain_id, sequence_number), (target_id, target_seq)
+                    )
+
+        self.db_pool.simple_insert_many_txn(
+            txn,
+            table="event_auth_chain_links",
+            values=[
+                {
+                    "origin_chain_id": source_id,
+                    "origin_sequence_number": source_seq,
+                    "target_chain_id": target_id,
+                    "target_sequence_number": target_seq,
+                }
+                for (
+                    source_id,
+                    source_seq,
+                    target_id,
+                    target_seq,
+                ) in chain_links.get_additions()
+            ],
+        )
 
     def _persist_transaction_ids_txn(
         self,
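
The allocation rule in the hunk above reduces to: walk the new state events in
topological order; if an event has an auth event with the same (type,
state_key) whose next sequence slot is free, extend that auth event's chain,
otherwise start a fresh chain at sequence number 1. A minimal in-memory sketch
under that assumption, with plain dicts standing in for the event_auth_chains
table and the DB-backed ID generator:

    import itertools
    from typing import Dict, List, Tuple

    def allocate_chains(
        order: List[str],                       # event IDs, topologically sorted
        auth: Dict[str, List[str]],             # event ID -> auth event IDs
        types: Dict[str, Tuple[str, str]],      # event ID -> (type, state_key)
        chain_map: Dict[str, Tuple[int, int]],  # existing allocations, updated in place
    ) -> None:
        taken = set(chain_map.values())
        next_chain = itertools.count(max((c for c, _ in taken), default=0) + 1)
        for event_id in order:
            candidate = None
            for auth_id in auth.get(event_id, []):
                # The first auth event with a matching type/state_key gives
                # the candidate chain to extend.
                if types.get(event_id) == types.get(auth_id):
                    chain_id, seq = chain_map[auth_id]
                    if (chain_id, seq + 1) not in taken:
                        candidate = (chain_id, seq + 1)
                    break
            if candidate is None:
                # No candidate, or its slot was already taken: new chain.
                candidate = (next(next_chain), 1)
            taken.add(candidate)
            chain_map[event_id] = candidate

    chain_map = {"$create": (1, 1)}
    allocate_chains(
        ["$m1", "$m2"],
        {"$m1": ["$create"], "$m2": ["$m1"]},
        {
            "$create": ("m.room.create", ""),
            "$m1": ("m.room.member", "@a:example.org"),
            "$m2": ("m.room.member", "@a:example.org"),
        },
        chain_map,
    )
    assert chain_map["$m1"] == (2, 1) and chain_map["$m2"] == (2, 2)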
@@ -799,7 +1174,8 @@ class PersistEventsStore:
         return [ec for ec in events_and_contexts if ec[0] not in to_remove]
 
     def _store_event_txn(self, txn, events_and_contexts):
-        """Insert new events into the event and event_json tables
+        """Insert new events into the event, event_json, redaction and
+        state_events tables.
 
         Args:
             txn (twisted.enterprise.adbapi.Connection): db connection
@@ -871,6 +1247,29 @@ class PersistEventsStore:
                     updatevalues={"have_censored": False},
                 )
 
+        state_events_and_contexts = [
+            ec for ec in events_and_contexts if ec[0].is_state()
+        ]
+
+        state_values = []
+        for event, context in state_events_and_contexts:
+            vals = {
+                "event_id": event.event_id,
+                "room_id": event.room_id,
+                "type": event.type,
+                "state_key": event.state_key,
+            }
+
+            # TODO: How does this work with backfilling?
+            if hasattr(event, "replaces_state"):
+                vals["prev_state"] = event.replaces_state
+
+            state_values.append(vals)
+
+        self.db_pool.simple_insert_many_txn(
+            txn, table="state_events", values=state_values
+        )
+
     def _store_rejected_events_txn(self, txn, events_and_contexts):
         """Add rows to the 'rejections' table for received events which were
         rejected
@@ -987,29 +1386,6 @@ class PersistEventsStore:
             txn, [event for event, _ in events_and_contexts]
         )
 
-        state_events_and_contexts = [
-            ec for ec in events_and_contexts if ec[0].is_state()
-        ]
-
-        state_values = []
-        for event, context in state_events_and_contexts:
-            vals = {
-                "event_id": event.event_id,
-                "room_id": event.room_id,
-                "type": event.type,
-                "state_key": event.state_key,
-            }
-
-            # TODO: How does this work with backfilling?
-            if hasattr(event, "replaces_state"):
-                vals["prev_state"] = event.replaces_state
-
-            state_values.append(vals)
-
-        self.db_pool.simple_insert_many_txn(
-            txn, table="state_events", values=state_values
-        )
-
         # Prefill the event cache
         self._add_to_cache(txn, events_and_contexts)
 
@@ -1520,3 +1896,131 @@ class PersistEventsStore:
                 if not ev.internal_metadata.is_outlier()
             ],
         )
+
+
+@attr.s(slots=True)
+class _LinkMap:
+    """A helper type for tracking links between chains.
+    """
+
+    # Stores the set of links as nested maps: source chain ID -> target chain ID
+    # -> source sequence number -> target sequence number.
+    maps = attr.ib(type=Dict[int, Dict[int, Dict[int, int]]], factory=dict)
+
+    # Stores the links that have been added (with new set to true), as tuples of
+    # `(source chain ID, source sequence no, target chain ID, target sequence no.)`
+    additions = attr.ib(type=Set[Tuple[int, int, int, int]], factory=set)
+
+    def add_link(
+        self,
+        src_tuple: Tuple[int, int],
+        target_tuple: Tuple[int, int],
+        new: bool = True,
+    ) -> bool:
+        """Add a new link between two chains, ensuring no redundant links are added.
+
+        New links should be added in topological order.
+
+        Args:
+            src_tuple: The chain ID/sequence number of the source of the link.
+            target_tuple: The chain ID/sequence number of the target of the link.
+            new: Whether this is a "new" link, i.e. should it be returned
+                by `get_additions`.
+
+        Returns:
+            True if a link was added, False if the given link was dropped as redundant.
+        """
+        src_chain, src_seq = src_tuple
+        target_chain, target_seq = target_tuple
+
+        current_links = self.maps.setdefault(src_chain, {}).setdefault(target_chain, {})
+
+        assert src_chain != target_chain
+
+        if new:
+            # Check if the new link is redundant
+            for current_seq_src, current_seq_target in current_links.items():
+                # If a link "crosses" another link then its redundant. For example
+                # in the following link 1 (L1) is redundant, as any event reachable
+                # via L1 is *also* reachable via L2.
+                #
+                #   Chain A     Chain B
+                #      |          |
+                #   L1 |------    |
+                #      |     |    |
+                #   L2 |---- | -->|
+                #      |     |    |
+                #      |     |--->|
+                #      |          |
+                #      |          |
+                #
+                # So we only need to keep links which *do not* cross, i.e. links
+                # that both start and end above or below an existing link.
+                #
+                # Note, since we add links in topological ordering we should never
+                # see `src_seq` less than `current_seq_src`.
+
+                if current_seq_src <= src_seq and target_seq <= current_seq_target:
+                    # This new link is redundant, nothing to do.
+                    return False
+
+            self.additions.add((src_chain, src_seq, target_chain, target_seq))
+
+        current_links[src_seq] = target_seq
+        return True
+
+    def get_links_from(
+        self, src_tuple: Tuple[int, int]
+    ) -> Generator[Tuple[int, int], None, None]:
+        """Gets the chains reachable from the given chain/sequence number.
+
+        Yields:
+            The chain ID and sequence number the link points to.
+        """
+        src_chain, src_seq = src_tuple
+        for target_id, sequence_numbers in self.maps.get(src_chain, {}).items():
+            for link_src_seq, target_seq in sequence_numbers.items():
+                if link_src_seq <= src_seq:
+                    yield target_id, target_seq
+
+    def get_links_between(
+        self, source_chain: int, target_chain: int
+    ) -> Generator[Tuple[int, int], None, None]:
+        """Gets the links between two chains.
+
+        Yields:
+            The source and target sequence numbers.
+        """
+
+        yield from self.maps.get(source_chain, {}).get(target_chain, {}).items()
+
+    def get_additions(self) -> Generator[Tuple[int, int, int, int], None, None]:
+        """Gets any newly added links.
+
+        Yields:
+            The source chain ID/sequence number and target chain ID/sequence number
+        """
+
+        for src_chain, src_seq, target_chain, _ in self.additions:
+            target_seq = self.maps.get(src_chain, {}).get(target_chain, {}).get(src_seq)
+            if target_seq is not None:
+                yield (src_chain, src_seq, target_chain, target_seq)
+
+    def exists_path_from(
+        self, src_tuple: Tuple[int, int], target_tuple: Tuple[int, int],
+    ) -> bool:
+        """Checks if there is a path between the source chain ID/sequence and
+        target chain ID/sequence.
+        """
+        src_chain, src_seq = src_tuple
+        target_chain, target_seq = target_tuple
+
+        if src_chain == target_chain:
+            return target_seq <= src_seq
+
+        links = self.get_links_between(src_chain, target_chain)
+        for link_start_seq, link_end_seq in links:
+            if link_start_seq <= src_seq and target_seq <= link_end_seq:
+                return True
+
+        return False
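
To make the link semantics above concrete: a single link (1, 3) -> (2, 2)
means every event on chain 1 at sequence 3 or later reaches every event on
chain 2 up to sequence 2, so a later link that crosses it, such as
(1, 5) -> (2, 1), adds nothing. A small usage sketch of the class above,
assuming it is importable as defined (it relies on the attrs package imported
at the top of the file):

    links = _LinkMap()
    assert links.add_link((1, 3), (2, 2))  # chain 1 @ seq 3 reaches chain 2 @ seq 2
    # Crossed link: anything at (2, <=1) was already reachable from (1, >=3),
    # and links are added in topological order, so this one is dropped.
    assert not links.add_link((1, 5), (2, 1))
    assert links.exists_path_from((1, 4), (2, 1))      # 3 <= 4 and 1 <= 2
    assert not links.exists_path_from((1, 2), (2, 1))  # below the link's origin
    assert list(links.get_additions()) == [(1, 3, 2, 2)]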
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 97b6754846..7e4b175d08 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -14,10 +14,15 @@
 # limitations under the License.
 
 import logging
+from typing import List, Tuple
 
 from synapse.api.constants import EventContentFields
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.events import make_event_from_dict
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
+from synapse.storage.types import Cursor
+from synapse.types import JsonDict
 
 logger = logging.getLogger(__name__)
 
@@ -99,6 +104,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             columns=["user_id", "created_ts"],
         )
 
+        self.db_pool.updates.register_background_update_handler(
+            "rejected_events_metadata", self._rejected_events_metadata,
+        )
+
     async def _background_reindex_fields_sender(self, progress, batch_size):
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
@@ -582,3 +591,118 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             await self.db_pool.updates._end_background_update("event_store_labels")
 
         return num_rows
+
+    async def _rejected_events_metadata(self, progress: dict, batch_size: int) -> int:
+        """Adds rejected events to the `state_events` and `event_auth` metadata
+        tables.
+        """
+
+        last_event_id = progress.get("last_event_id", "")
+
+        def get_rejected_events(
+            txn: Cursor,
+        ) -> List[Tuple[str, str, JsonDict, bool, bool]]:
+            # Fetch rejected events' JSON, their room version and whether we
+            # have inserted them into the state_events or event_auth tables.
+            #
+            # Note we can assume that events that don't have a corresponding
+            # room version are V1 rooms.
+            sql = """
+                SELECT DISTINCT
+                    event_id,
+                    COALESCE(room_version, '1'),
+                    json,
+                    state_events.event_id IS NOT NULL,
+                    event_auth.event_id IS NOT NULL
+                FROM rejections
+                INNER JOIN event_json USING (event_id)
+                LEFT JOIN rooms USING (room_id)
+                LEFT JOIN state_events USING (event_id)
+                LEFT JOIN event_auth USING (event_id)
+                WHERE event_id > ?
+                ORDER BY event_id
+                LIMIT ?
+            """
+
+            txn.execute(sql, (last_event_id, batch_size,))
+
+            return [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn]  # type: ignore
+
+        results = await self.db_pool.runInteraction(
+            desc="_rejected_events_metadata_get", func=get_rejected_events
+        )
+
+        if not results:
+            await self.db_pool.updates._end_background_update(
+                "rejected_events_metadata"
+            )
+            return 0
+
+        state_events = []
+        auth_events = []
+        for event_id, room_version, event_json, has_state, has_event_auth in results:
+            last_event_id = event_id
+
+            if has_state and has_event_auth:
+                continue
+
+            room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version)
+            if not room_version_obj:
+                # We no longer support this room version, so we just ignore the
+                # events entirely.
+                logger.info(
+                    "Ignoring event with unknown room version %r: %r",
+                    room_version,
+                    event_id,
+                )
+                continue
+
+            event = make_event_from_dict(event_json, room_version_obj)
+
+            if not event.is_state():
+                continue
+
+            if not has_state:
+                state_events.append(
+                    {
+                        "event_id": event.event_id,
+                        "room_id": event.room_id,
+                        "type": event.type,
+                        "state_key": event.state_key,
+                    }
+                )
+
+            if not has_event_auth:
+                for auth_id in event.auth_event_ids():
+                    auth_events.append(
+                        {
+                            "room_id": event.room_id,
+                            "event_id": event.event_id,
+                            "auth_id": auth_id,
+                        }
+                    )
+
+        if state_events:
+            await self.db_pool.simple_insert_many(
+                table="state_events",
+                values=state_events,
+                desc="_rejected_events_metadata_state_events",
+            )
+
+        if auth_events:
+            await self.db_pool.simple_insert_many(
+                table="event_auth",
+                values=auth_events,
+                desc="_rejected_events_metadata_event_auth",
+            )
+
+        await self.db_pool.updates._background_update_progress(
+            "rejected_events_metadata", {"last_event_id": last_event_id}
+        )
+
+        if len(results) < batch_size:
+            await self.db_pool.updates._end_background_update(
+                "rejected_events_metadata"
+            )
+
+        return len(results)
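
The update above follows Synapse's usual background-update shape: resume from
a stored progress token, process a bounded batch, persist the new token, and
finish when an empty (or short) batch signals completion. A stripped-down
sketch of that pattern; fetch_batch and process_batch are hypothetical
stand-ins for the event-specific work:

    from typing import List

    async def paginated_update(db_pool, progress: dict, batch_size: int) -> int:
        last_event_id = progress.get("last_event_id", "")

        # Hypothetical: return up to batch_size event IDs after the token.
        rows: List[str] = await fetch_batch(db_pool, last_event_id, batch_size)

        if not rows:
            await db_pool.updates._end_background_update("my_update")
            return 0

        await process_batch(db_pool, rows)  # hypothetical stand-in

        # Persist how far we got so a restart resumes from here.
        await db_pool.updates._background_update_progress(
            "my_update", {"last_event_id": rows[-1]}
        )

        if len(rows) < batch_size:
            await db_pool.updates._end_background_update("my_update")

        return len(rows)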
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 0e25ca3d7a..54ef0f1f54 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -82,7 +82,7 @@ class ProfileWorkerStore(SQLBaseStore):
         )
 
     async def set_profile_avatar_url(
-        self, user_localpart: str, new_avatar_url: str
+        self, user_localpart: str, new_avatar_url: Optional[str]
     ) -> None:
         await self.db_pool.simple_update_one(
             table="profiles",
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 4650d0689b..284f2ce77c 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -84,7 +84,7 @@ class RoomWorkerStore(SQLBaseStore):
         return await self.db_pool.simple_select_one(
             table="rooms",
             keyvalues={"room_id": room_id},
-            retcols=("room_id", "is_public", "creator"),
+            retcols=("room_id", "is_public", "creator", "has_auth_chain_index"),
             desc="get_room",
             allow_none=True,
         )
@@ -1166,6 +1166,37 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
         # It's overridden by RoomStore for the synapse master.
         raise NotImplementedError()
 
+    async def has_auth_chain_index(self, room_id: str) -> bool:
+        """Check if the room has (or can have) a chain cover index.
+
+        Defaults to True if we have neither an entry in the `rooms` table nor
+        any events for the room.
+        """
+
+        has_auth_chain_index = await self.db_pool.simple_select_one_onecol(
+            table="rooms",
+            keyvalues={"room_id": room_id},
+            retcol="has_auth_chain_index",
+            desc="has_auth_chain_index",
+            allow_none=True,
+        )
+
+        if has_auth_chain_index:
+            return True
+
+        # It's possible that we already have events for the room in our DB
+        # without a corresponding room entry. If we do then we don't want to
+        # mark the room as having an auth chain cover index.
+        max_ordering = await self.db_pool.simple_select_one_onecol(
+            table="events",
+            keyvalues={"room_id": room_id},
+            retcol="MAX(stream_ordering)",
+            allow_none=True,
+            desc="upsert_room_on_join",
+        )
+
+        return max_ordering is None
+
 
 class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
@@ -1179,12 +1210,21 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
         Called when we join a room over federation, and overwrites any room version
         currently in the table.
         """
+        # It's possible that we already have events for the room in our DB
+        # without a corresponding room entry. If we do then we don't want to
+        # mark the room as having an auth chain cover index.
+        has_auth_chain_index = await self.has_auth_chain_index(room_id)
+
         await self.db_pool.simple_upsert(
             desc="upsert_room_on_join",
             table="rooms",
             keyvalues={"room_id": room_id},
             values={"room_version": room_version.identifier},
-            insertion_values={"is_public": False, "creator": ""},
+            insertion_values={
+                "is_public": False,
+                "creator": "",
+                "has_auth_chain_index": has_auth_chain_index,
+            },
             # rooms has a unique constraint on room_id, so no need to lock when doing an
             # emulated upsert.
             lock=False,
@@ -1219,6 +1259,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                         "creator": room_creator_user_id,
                         "is_public": is_public,
                         "room_version": room_version.identifier,
+                        "has_auth_chain_index": True,
                     },
                 )
                 if is_public:
@@ -1247,6 +1288,11 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
         When we receive an invite or any other event over federation that may relate to a room
         we are not in, store the version of the room if we don't already know the room version.
         """
+        # It's possible that we already have events for the room in our DB
+        # without a corresponding room entry. If we do then we don't want to
+        # mark the room as having an auth chain cover index.
+        has_auth_chain_index = await self.has_auth_chain_index(room_id)
+
         await self.db_pool.simple_upsert(
             desc="maybe_store_room_on_outlier_membership",
             table="rooms",
@@ -1256,6 +1302,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                 "room_version": room_version.identifier,
                 "is_public": False,
                 "creator": "",
+                "has_auth_chain_index": has_auth_chain_index,
             },
             # rooms has a unique constraint on room_id, so no need to lock when doing an
             # emulated upsert.
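
The rule implemented by has_auth_chain_index and the two upsert paths above
boils down to a pure predicate: an explicit flag on the rooms row wins, and
otherwise the index may only be enabled for rooms with no events persisted
yet, since older events would lack chain data. A minimal sketch of that
predicate:

    from typing import Optional

    def use_chain_cover(flag: Optional[bool], max_stream_ordering: Optional[int]) -> bool:
        # An explicit has_auth_chain_index on the rooms row wins.
        if flag:
            return True
        # Otherwise it is only safe when no events exist for the room yet.
        return max_stream_ordering is None

    assert use_chain_cover(True, 123)       # flagged room keeps its index
    assert use_chain_cover(None, None)      # brand new room: safe to enable
    assert not use_chain_cover(None, 42)    # pre-existing events lack chain data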
diff --git a/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.postgres b/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.postgres
new file mode 100644
index 0000000000..de57645019
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.postgres
@@ -0,0 +1,16 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE access_tokens DROP COLUMN last_used;
\ No newline at end of file
diff --git a/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.sqlite b/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.sqlite
new file mode 100644
index 0000000000..ee0e3521bf
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.sqlite
@@ -0,0 +1,62 @@
+/*
+ * Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Dropping last_used column from access_tokens table.
+
+CREATE TABLE access_tokens2 (
+    id BIGINT PRIMARY KEY,
+    user_id TEXT NOT NULL,
+    device_id TEXT,
+    token TEXT NOT NULL,
+    valid_until_ms BIGINT,
+    puppets_user_id TEXT,
+    last_validated BIGINT,
+    UNIQUE(token)
+);
+
+INSERT INTO access_tokens2(id, user_id, device_id, token)
+    SELECT id, user_id, device_id, token FROM access_tokens;
+
+DROP TABLE access_tokens;
+ALTER TABLE access_tokens2 RENAME TO access_tokens;
+
+CREATE INDEX access_tokens_device_id ON access_tokens (user_id, device_id);
+
+
+-- Re-adding foreign key reference in event_txn_id table
+
+CREATE TABLE event_txn_id2 (
+    event_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    token_id BIGINT NOT NULL,
+    txn_id TEXT NOT NULL,
+    inserted_ts BIGINT NOT NULL,
+    FOREIGN KEY (event_id)
+        REFERENCES events (event_id) ON DELETE CASCADE,
+    FOREIGN KEY (token_id)
+        REFERENCES access_tokens (id) ON DELETE CASCADE
+);
+
+INSERT INTO event_txn_id2(event_id, room_id, user_id, token_id, txn_id, inserted_ts)
+    SELECT event_id, room_id, user_id, token_id, txn_id, inserted_ts FROM event_txn_id;
+
+DROP TABLE event_txn_id;
+ALTER TABLE event_txn_id2 RENAME TO event_txn_id;
+
+CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_event_id ON event_txn_id(event_id);
+CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_txn_id ON event_txn_id(room_id, user_id, token_id, txn_id);
+CREATE INDEX IF NOT EXISTS event_txn_id_ts ON event_txn_id(inserted_ts);
\ No newline at end of file
diff --git a/synapse/storage/databases/main/schema/delta/58/28rejected_events_metadata.sql b/synapse/storage/databases/main/schema/delta/58/28rejected_events_metadata.sql
new file mode 100644
index 0000000000..9c95646281
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/28rejected_events_metadata.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (5828, 'rejected_events_metadata', '{}');
diff --git a/synapse/storage/databases/main/schema/delta/59/01ignored_user.py b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
new file mode 100644
index 0000000000..f35c70b699
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
@@ -0,0 +1,82 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This migration denormalises the account_data table into an ignored users table.
+"""
+
+import logging
+from io import StringIO
+
+from synapse.storage._base import db_to_json
+from synapse.storage.engines import BaseDatabaseEngine
+from synapse.storage.prepare_database import execute_statements_from_stream
+from synapse.storage.types import Cursor
+
+logger = logging.getLogger(__name__)
+
+
+def run_upgrade(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs):
+    pass
+
+
+def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs):
+    logger.info("Creating ignored_users table")
+    execute_statements_from_stream(cur, StringIO(_create_commands))
+
+    # We now upgrade existing data, if any. We don't do this in `run_upgrade` as
+    # we a) want to run these before adding constraints and b) `run_upgrade` is
+    # not run on empty databases.
+    insert_sql = """
+    INSERT INTO ignored_users (ignorer_user_id, ignored_user_id) VALUES (?, ?)
+    """
+
+    logger.info("Converting existing ignore lists")
+    cur.execute(
+        "SELECT user_id, content FROM account_data WHERE account_data_type = 'm.ignored_user_list'"
+    )
+    for user_id, content_json in cur.fetchall():
+        content = db_to_json(content_json)
+
+        # The content should be a dictionary with an "ignored_users" key
+        # pointing to a dictionary whose keys are the ignored user IDs, e.g.:
+        #
+        # { "ignored_users": { "@someone:example.org": {} } }
+        ignored_users = content.get("ignored_users", {})
+        if isinstance(ignored_users, dict) and ignored_users:
+            cur.executemany(insert_sql, [(user_id, u) for u in ignored_users])
+
+    # Add indexes after inserting data for efficiency.
+    logger.info("Adding constraints to ignored_users table")
+    execute_statements_from_stream(cur, StringIO(_constraints_commands))
+
+
+# there might be duplicates, so the easiest way to achieve this is to create a
+# new table with the right data, and rename it into place
+
+_create_commands = """
+-- Users which are ignored when calculating push notifications. This data is
+-- denormalized from account data.
+CREATE TABLE IF NOT EXISTS ignored_users(
+    ignorer_user_id TEXT NOT NULL,  -- The user ID of the user who is ignoring another user. (This is a local user.)
+    ignored_user_id TEXT NOT NULL  -- The user ID of the user who is being ignored. (This is a local or remote user.)
+);
+"""
+
+_constraints_commands = """
+CREATE UNIQUE INDEX ignored_users_uniqueness ON ignored_users (ignorer_user_id, ignored_user_id);
+
+-- Add an index on ignored_users since look-ups are done to get all ignorers of an ignored user.
+CREATE INDEX ignored_users_ignored_user_id ON ignored_users (ignored_user_id);
+"""
diff --git a/synapse/storage/databases/main/schema/delta/59/02shard_send_to_device.sql b/synapse/storage/databases/main/schema/delta/59/02shard_send_to_device.sql
new file mode 100644
index 0000000000..d781a92fec
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/02shard_send_to_device.sql
@@ -0,0 +1,18 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE device_inbox ADD COLUMN instance_name TEXT;
+ALTER TABLE device_federation_inbox ADD COLUMN instance_name TEXT;
+ALTER TABLE device_federation_outbox ADD COLUMN instance_name TEXT;
diff --git a/synapse/storage/databases/main/schema/delta/59/03shard_send_to_device_sequence.sql.postgres b/synapse/storage/databases/main/schema/delta/59/03shard_send_to_device_sequence.sql.postgres
new file mode 100644
index 0000000000..45a845a3a5
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/03shard_send_to_device_sequence.sql.postgres
@@ -0,0 +1,25 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE IF NOT EXISTS device_inbox_sequence;
+
+-- We need to take the max across both device_inbox and device_federation_outbox
+-- tables as they share the ID generator
+SELECT setval('device_inbox_sequence', (
+    SELECT GREATEST(
+        (SELECT COALESCE(MAX(stream_id), 1) FROM device_inbox),
+        (SELECT COALESCE(MAX(stream_id), 1) FROM device_federation_outbox)
+    )
+));
diff --git a/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql
new file mode 100644
index 0000000000..729196cfd5
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql
@@ -0,0 +1,52 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- See docs/auth_chain_difference_algorithm.md
+
+CREATE TABLE event_auth_chains (
+  event_id TEXT PRIMARY KEY,
+  chain_id BIGINT NOT NULL,
+  sequence_number BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX event_auth_chains_c_seq_index ON event_auth_chains (chain_id, sequence_number);
+
+
+CREATE TABLE event_auth_chain_links (
+  origin_chain_id BIGINT NOT NULL,
+  origin_sequence_number BIGINT NOT NULL,
+
+  target_chain_id BIGINT NOT NULL,
+  target_sequence_number BIGINT NOT NULL
+);
+
+
+CREATE INDEX event_auth_chain_links_idx ON event_auth_chain_links (origin_chain_id, target_chain_id);
+
+
+-- Events that we have persisted but not calculated auth chains for,
+-- e.g. out of band memberships (where we don't have the auth chain)
+CREATE TABLE event_auth_chain_to_calculate (
+  event_id TEXT PRIMARY KEY,
+  room_id TEXT NOT NULL,
+  type TEXT NOT NULL,
+  state_key TEXT NOT NULL
+);
+
+CREATE INDEX event_auth_chain_to_calculate_rm_id ON event_auth_chain_to_calculate(room_id);
+
+
+-- Whether we've calculated the above index for a room.
+ALTER TABLE rooms ADD COLUMN has_auth_chain_index BOOLEAN;
diff --git a/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres
new file mode 100644
index 0000000000..e8a035bbeb
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres
@@ -0,0 +1,16 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE IF NOT EXISTS event_auth_chain_id;
diff --git a/synapse/storage/databases/main/schema/delta/59/04drop_account_data.sql b/synapse/storage/databases/main/schema/delta/59/04drop_account_data.sql
new file mode 100644
index 0000000000..64ab696cfe
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/04drop_account_data.sql
@@ -0,0 +1,17 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This is no longer used and was only kept until we bumped the schema version.
+DROP TABLE IF EXISTS account_data_max_stream_id;
diff --git a/synapse/storage/databases/main/schema/delta/59/05cache_invalidation.sql b/synapse/storage/databases/main/schema/delta/59/05cache_invalidation.sql
new file mode 100644
index 0000000000..fb71b360a0
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/05cache_invalidation.sql
@@ -0,0 +1,17 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This is no longer used and was only kept until we bumped the schema version.
+DROP TABLE IF EXISTS cache_invalidation_stream;
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 9f120d3cb6..74da9c49f2 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -255,16 +255,6 @@ class TagsStore(TagsWorkerStore):
             self._account_data_stream_cache.entity_has_changed, user_id, next_id
         )
 
-        # Note: This is only here for backwards compat to allow admins to
-        # roll back to a previous Synapse version. Next time we update the
-        # database version we can remove this table.
-        update_max_id_sql = (
-            "UPDATE account_data_max_stream_id"
-            " SET stream_id = ?"
-            " WHERE stream_id < ?"
-        )
-        txn.execute(update_max_id_sql, (next_id, next_id))
-
         update_sql = (
             "UPDATE room_tags_revisions"
             " SET stream_id = ?"
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index f91a2eae7a..566ea19bae 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -35,10 +35,7 @@ logger = logging.getLogger(__name__)
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-# XXX: If you're about to bump this to 59 (or higher) please create an update
-# that drops the unused `cache_invalidation_stream` table, as per #7436!
-# XXX: Also add an update to drop `account_data_max_stream_id` as per #7656!
-SCHEMA_VERSION = 58
+SCHEMA_VERSION = 59
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
 
@@ -375,7 +372,16 @@ def _upgrade_existing_database(
     specific_engine_extensions = (".sqlite", ".postgres")
 
     for v in range(start_ver, SCHEMA_VERSION + 1):
-        logger.info("Applying schema deltas for v%d", v)
+        if not is_worker:
+            logger.info("Applying schema deltas for v%d", v)
+
+            cur.execute("DELETE FROM schema_version")
+            cur.execute(
+                "INSERT INTO schema_version (version, upgraded) VALUES (?,?)",
+                (v, True),
+            )
+        else:
+            logger.info("Checking schema deltas for v%d", v)
 
         # We need to search both the global and per data store schema
         # directories for schema updates.
@@ -489,12 +495,6 @@ def _upgrade_existing_database(
                 (v, relative_path),
             )
 
-            cur.execute("DELETE FROM schema_version")
-            cur.execute(
-                "INSERT INTO schema_version (version, upgraded) VALUES (?,?)",
-                (v, True),
-            )
-
     logger.info("Schema now up to date")
 
 
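
After the reordering above, the schema_version row is written once per
version, before that version's deltas run, and only on the main process;
workers merely log that they are checking. A sketch of the loop's new shape,
where apply_deltas_for_version is a hypothetical stand-in for the
delta-scanning body:

    def apply_deltas_for_version(cur, v: int) -> None:
        """Hypothetical stand-in for the per-version delta application."""

    def upgrade(cur, start_ver: int, schema_version: int, is_worker: bool) -> None:
        for v in range(start_ver, schema_version + 1):
            if not is_worker:
                # Bump the recorded version up front, rather than after each
                # delta file as the old code did.
                cur.execute("DELETE FROM schema_version")
                cur.execute(
                    "INSERT INTO schema_version (version, upgraded) VALUES (?,?)",
                    (v, True),
                )
            apply_deltas_for_version(cur, v)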
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 601305487c..1adc92eb90 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -105,7 +105,7 @@ class DeferredCache(Generic[KT, VT]):
             keylen=keylen,
             cache_name=name,
             cache_type=cache_type,
-            size_callback=(lambda d: len(d)) if iterable else None,
+            size_callback=(lambda d: len(d) or 1) if iterable else None,
             metrics_collection_callback=metrics_cb,
             apply_cache_factor_from_config=apply_cache_factor_from_config,
         )  # type: LruCache[KT, VT]
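
The `or 1` in the hunk above makes empty iterable values count as size 1
rather than 0, presumably so that cached empty collections still count towards
the cache's size. The effect in isolation:

    def size_callback(d):
        # Empty values still occupy one slot; non-empty keep their length.
        return len(d) or 1

    assert size_callback([]) == 1
    assert size_callback([1, 2, 3]) == 3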
diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py
index 06faeebe7f..f7b4857a84 100644
--- a/synapse/util/iterutils.py
+++ b/synapse/util/iterutils.py
@@ -13,8 +13,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import heapq
 from itertools import islice
-from typing import Iterable, Iterator, Sequence, Tuple, TypeVar
+from typing import (
+    Dict,
+    Generator,
+    Iterable,
+    Iterator,
+    Mapping,
+    Sequence,
+    Set,
+    Tuple,
+    TypeVar,
+)
+
+from synapse.types import Collection
 
 T = TypeVar("T")
 
@@ -46,3 +59,41 @@ def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]:
     If the input is empty, no chunks are returned.
     """
     return (iseq[i : i + maxlen] for i in range(0, len(iseq), maxlen))
+
+
+def sorted_topologically(
+    nodes: Iterable[T], graph: Mapping[T, Collection[T]],
+) -> Generator[T, None, None]:
+    """Given a set of nodes and a graph, yield the nodes in toplogical order.
+
+    For example `sorted_topologically([1, 2], {1: [2]})` will yield `2, 1`.
+    """
+
+    # This is implemented by Kahn's algorithm.
+
+    degree_map = {node: 0 for node in nodes}
+    reverse_graph = {}  # type: Dict[T, Set[T]]
+
+    for node, edges in graph.items():
+        if node not in degree_map:
+            continue
+
+        for edge in edges:
+            if edge in degree_map:
+                degree_map[node] += 1
+
+            reverse_graph.setdefault(edge, set()).add(node)
+        reverse_graph.setdefault(node, set())
+
+    zero_degree = [node for node, degree in degree_map.items() if degree == 0]
+    heapq.heapify(zero_degree)
+
+    while zero_degree:
+        node = heapq.heappop(zero_degree)
+        yield node
+
+        for edge in reverse_graph[node]:
+            if edge in degree_map:
+                degree_map[edge] -= 1
+                if degree_map[edge] == 0:
+                    heapq.heappush(zero_degree, edge)
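
A usage note for the function above, assuming it is importable as defined:
`graph` maps each node to the nodes it depends on, dependencies are yielded
before their dependents, and the heap means ties come out in sorted order.
Edges pointing at nodes outside `nodes` are ignored:

    assert list(sorted_topologically([1, 2], {1: [2]})) == [2, 1]
    assert list(sorted_topologically([3, 2, 1], {1: [2], 2: [3]})) == [3, 2, 1]
    assert list(sorted_topologically([1], {1: [99]})) == [1]  # 99 not a node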
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index ffdea0de8d..f4de6b9f54 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -108,7 +108,16 @@ class Measure:
     def __init__(self, clock, name):
         self.clock = clock
         self.name = name
-        parent_context = current_context()
+        curr_context = current_context()
+        if not curr_context:
+            logger.warning(
+                "Starting metrics collection %r from sentinel context: metrics will be lost",
+                name,
+            )
+            parent_context = None
+        else:
+            assert isinstance(curr_context, LoggingContext)
+            parent_context = curr_context
         self._logging_context = LoggingContext(
             "Measure[%s]" % (self.name,), parent_context
         )
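
The guard added above amounts to a small context-resolution rule: a falsy
(sentinel) current context yields no parent and a warning, instead of silently
attributing metrics to the sentinel. A generic sketch of that rule, outside
Synapse's logging machinery:

    import logging

    logger = logging.getLogger(__name__)

    def resolve_parent_context(curr_context, name: str):
        # Mirrors the guard above: refuse to treat the sentinel (falsy)
        # context as a parent; metrics gathered there would be lost anyway.
        if not curr_context:
            logger.warning(
                "Starting metrics collection %r from sentinel context: metrics will be lost",
                name,
            )
            return None
        return curr_context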